// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/ir_adaptor/translator/op_translator.h"

#include <algorithm>
#include <cctype>
#include <numeric>
#include <regex>
#include <string>
#include <tuple>
#include <typeinfo>
#include <unordered_map>
#include <vector>

#include "paddle/common/ddim.h"
#include "paddle/common/enforce.h"
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/pir/include/core/attribute.h"
#include "paddle/pir/include/core/builder.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/operation.h"
#include "paddle/pir/include/core/utils.h"
#include "paddle/pir/include/core/value.h"
#include "paddle/utils/blank.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#endif
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

namespace paddle::translator {

namespace {

// Index of a result within an operation / within a VectorType result.
using IdxInOp = size_t;
using IdxInVector = size_t;
// Locates one translated output: (result index in op, index inside vector).
using ResultIdx = std::tuple<IdxInOp, IdxInVector>;
using OpDesc = paddle::framework::OpDesc;
using BlockDesc = paddle::framework::BlockDesc;
using VarDesc = paddle::framework::VarDesc;
using OpOutputTypeList = std::vector<pir::Type>;
// Maps a legacy output variable name to where its value lives in the new op.
using OpOutputMapping = std::unordered_map<std::string, ResultIdx>;
using OpInputInfo = dialect::OpInputInfo;
using OpInputInfoList = std::vector<dialect::OpInputInfo>;
using OpAttributeInfo = dialect::OpAttributeInfo;
using OpAttributeInfoList = std::vector<dialect::OpAttributeInfo>;
using OpOutputInfo = dialect::OpOutputInfo;
using OpOutputInfoList = std::vector<dialect::OpOutputInfo>;
// Callback type for inputs that need op-specific translation logic.
using InputHandlerFn = std::function<pir::Value(pir::IrContext*,
                                                TranslationContext*,
                                                const OpDesc&,
                                                const std::string&,
                                                const OpInputInfo&,
                                                pir::Block*)>;
// Callback type for attributes that need op-specific translation logic.
using AttributeHandlerFn = std::function<pir::Attribute(
    pir::IrContext*, const OpDesc&, const OpAttributeInfo&)>;
using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
// Dialect prefixes prepended to normalized op names during registry lookup.
constexpr char kTargetDialectPrefix[] = "pd_op.";  // NOLINT
#ifdef PADDLE_WITH_DNNL
constexpr char kOneDNNTargetDialectPrefix[] = "onednn_op.";  // NOLINT
#endif
constexpr char kCustomOpDialectPrefix[] = "custom_op.";  // NOLINT
// Placeholder used for absent entries inside vector-valued outputs.
constexpr char kEmptyVarName[] = "@EMPTY@";              // NOLINT

// Ops that must never be treated as inplace even if input/output names overlap.
static const std::unordered_set<std::string> SpecialNonInplaceOps = {};

// Ops that are always translated to their inplace ("..._") variants.
static const std::unordered_set<std::string> SpecialInplaceOps = {
    "adagrad",
    "adam",
    "adamax",
    "adamw",
};

// Returns true when the op should be translated to its inplace variant.
// Explicit allow/deny lists take precedence; otherwise an op is considered
// inplace when any variable name appears both as an input and as an output.
inline bool IsInplace(const OpDesc& op_desc) {
  const std::string& op_type = op_desc.Type();
  if (SpecialNonInplaceOps.count(op_type) != 0) {
    return false;
  }
  if (SpecialInplaceOps.count(op_type) != 0) {
    return true;
  }

  auto inputs = op_desc.InputArgumentNames();
  auto outputs = op_desc.OutputArgumentNames();
  if (inputs.empty() || outputs.empty()) {
    return false;
  }

  // set_intersection requires both ranges sorted.
  std::sort(inputs.begin(), inputs.end());
  std::sort(outputs.begin(), outputs.end());
  std::vector<std::string> shared_names;
  std::set_intersection(inputs.begin(),
                        inputs.end(),
                        outputs.begin(),
                        outputs.end(),
                        std::back_inserter(shared_names));
  if (shared_names.empty()) {
    return false;
  }

  // Log the overlapping names (comma separated) for debugging.
  std::string joined = shared_names.front();
  for (size_t i = 1; i < shared_names.size(); ++i) {
    joined += ",";
    joined += shared_names[i];
  }
  VLOG(4) << "Following variables occur both in inputs and outputs: "
          << joined;
  return true;
}

// Maps a legacy (ProgramDesc) op name to its normalized PIR counterpart.
inline std::string OpNameCompatibleMapping(std::string op_name) {
  return OpNameNormalizer::instance()[op_name];
}
// Returns true when `str` matches the pattern ^sparse_[a-zA-Z0-9_]+$, i.e. a
// "sparse_" prefix followed by at least one identifier character.
// Hand-rolled instead of std::regex: the original built (and compiled) a
// std::regex on every call, which is expensive for a predicate that runs once
// per translated op. Behavior is identical to the regex version.
inline bool isSparseString(const std::string& str) {
  static const std::string kPrefix = "sparse_";
  // Must be strictly longer than the prefix ("sparse_" alone does not match).
  if (str.size() <= kPrefix.size() ||
      str.compare(0, kPrefix.size(), kPrefix) != 0) {
    return false;
  }
  // Every remaining character must be [a-zA-Z0-9_] (locale independent).
  for (auto it = str.begin() + kPrefix.size(); it != str.end(); ++it) {
    const char c = *it;
    const bool valid = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
                       (c >= '0' && c <= '9') || c == '_';
    if (!valid) {
      return false;
    }
  }
  return true;
}

// Packs the already-translated values named in `args` into a single
// VectorType result by inserting a builtin CombineOp at the end of `block`.
// Every name in `args` must already be present in `param_map`.
inline pir::Operation* InsertCombineOperationForTarget(
    pir::IrContext* ctx,
    TranslationContext* param_map,
    pir::Block* block,
    const std::vector<std::string>& args) {
  std::vector<pir::Value> inputs;
  std::vector<pir::Type> element_types;
  inputs.reserve(args.size());
  element_types.reserve(args.size());
  for (const auto& name : args) {
    const auto& defining_info = param_map->at(name);
    inputs.push_back(defining_info.value);
    element_types.push_back(defining_info.value.type());
  }

  pir::OpInfo combine_info = ctx->GetRegisteredOpInfo(pir::CombineOp::name());
  pir::Type vec_type = pir::VectorType::get(ctx, element_types);
  pir::Operation* combine_op =
      pir::Operation::Create(inputs, {}, {vec_type}, combine_info);
  block->push_back(combine_op);
  return combine_op;
}

// Materializes a scalar attribute as a 1-element `full` op, so ops whose PIR
// signature declares the value as a (mutable-attribute) tensor input can
// consume it. Returns the created FullOp.
// NOTE(review): the payload always travels as a float while `dtype` records
// the original type; for Double and Scalar attributes the value is narrowed
// to float although dtype stays 64-bit -- possible precision loss, confirm
// this is intended.
inline pir::Operation* InsertFullOperationForAttributeInput(
    pir::IrContext* ctx, pir::Block* block, pir::Attribute attr) {
  float data = 0.0f;
  phi::DataType dtype = phi::DataType::UNDEFINED;
  if (attr.isa<pir::FloatAttribute>()) {
    data = attr.dyn_cast<pir::FloatAttribute>().data();
    dtype = phi::DataType::FLOAT32;
  } else if (attr.isa<pir::DoubleAttribute>()) {
    data = static_cast<float>(attr.dyn_cast<pir::DoubleAttribute>().data());
    dtype = phi::DataType::FLOAT64;
  } else if (attr.isa<pir::Int32Attribute>()) {
    data = static_cast<float>(attr.dyn_cast<pir::Int32Attribute>().data());
    dtype = phi::DataType::INT32;
  } else if (attr.isa<pir::Int64Attribute>()) {
    data = static_cast<float>(attr.dyn_cast<pir::Int64Attribute>().data());
    dtype = phi::DataType::INT64;
  } else if (attr.isa<pir::BoolAttribute>()) {
    data = static_cast<float>(attr.dyn_cast<pir::BoolAttribute>().data());
    dtype = phi::DataType::BOOL;
  } else if (attr.isa<dialect::ScalarAttribute>()) {
    // TODO(phlrain) : need update here, downcast from double to float
    data = static_cast<float>(
        attr.dyn_cast<dialect::ScalarAttribute>().data().to<double>());
    dtype = phi::DataType::FLOAT64;
  }
  // Unrecognized attribute kinds fall through with data=0, dtype=UNDEFINED.
  pir::Builder builder(ctx, block);
  dialect::FullOp full_op = builder.Build<dialect::FullOp>(
      std::vector<int64_t>{1}, data, dtype, phi::CPUPlace());

  return full_op.operation();
}

// Materializes an IntArray attribute as a `full_int_array` op so it can feed
// an input slot that PIR declares as a mutable attribute. Throws when `attr`
// is not an IntArrayAttribute.
inline pir::Operation* InsertFullArrayOperationForAttributeInput(
    pir::IrContext* ctx, pir::Block* block, pir::Attribute attr) {
  PADDLE_ENFORCE_EQ(attr.isa<dialect::IntArrayAttribute>(),
                    true,
                    common::errors::InvalidArgument(
                        "Encounter non IntArray type when trying to "
                        "insert IntArray mutable attribute"));
  const phi::IntArray& int_array =
      attr.dyn_cast<dialect::IntArrayAttribute>().data();
  pir::Builder builder(ctx, block);
  auto full_int_array_op = builder.Build<dialect::FullIntArrayOp>(
      int_array.GetData(), phi::DataType::INT64, phi::CPUPlace());
  return full_int_array_op.operation();
}

// Combines the values named in `args` into a vector and stacks them along
// `axis` (default 0). Returns the created StackOp.
inline pir::Operation* InsertStackOperationForTarget(
    pir::IrContext* ctx,
    TranslationContext* param_map,
    pir::Block* block,
    const std::vector<std::string>& args,
    int axis = 0) {
  pir::Operation* combined =
      InsertCombineOperationForTarget(ctx, param_map, block, args);
  pir::Builder builder(ctx, block);
  auto stacked = builder.Build<dialect::StackOp>(combined->result(0), axis);
  return stacked.operation();
}

// Inserts a CreateArrayOp whose element dtype comes from the legacy
// VarDesc's declared data type.
inline pir::Operation* InsertCreateArrayOp(pir::IrContext* ctx,
                                           pir::Block* block,
                                           const VarDesc* var) {
  phi::DataType dtype = phi::TransToPhiDataType(var->GetDataType());
  pir::Builder builder(ctx, block);
  return builder.Build<dialect::CreateArrayOp>(dtype).operation();
}

// Returns true when an op named `prefix + normalized-name` (with a trailing
// '_' appended for inplace ops) is registered in `ctx`.
// Changes from the original: `prefix` is taken by const reference instead of
// by value (it was copied on every call), and the `if (op_info) return true;
// return false;` tail is collapsed into a direct boolean conversion.
inline bool HasOpInfo(pir::IrContext* ctx,
                      const OpDesc& op_desc,
                      const std::string& prefix) {
  std::string target_op_name = prefix + OpNameCompatibleMapping(op_desc.Type());
  // Inplace op variants are registered with a trailing underscore in PIR.
  if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') {
    target_op_name += "_";
  }
  return static_cast<bool>(ctx->GetRegisteredOpInfo(target_op_name));
}

// Chooses the dialect prefix for the translated op: "custom_op." when a
// matching custom-op registration exists; otherwise (when built with OneDNN)
// "onednn_op." when the op requests or requires OneDNN and a OneDNN op is
// registered; "pd_op." in all remaining cases.
inline std::string GetPrefix(pir::IrContext* ctx, const OpDesc& op_desc) {
  if (HasOpInfo(ctx, op_desc, kCustomOpDialectPrefix)) {
    return kCustomOpDialectPrefix;
  }
#ifdef PADDLE_WITH_DNNL
  if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
      op_desc.GetAttrIfExists<bool>("use_onednn") ||
      paddle::dialect::IsOneDNNOnlyOp(op_desc.Type())) {
    // The op wants OneDNN, but fall back to pd_op when no OneDNN op exists.
    if (!HasOpInfo(ctx, op_desc, kOneDNNTargetDialectPrefix)) {
      VLOG(3) << op_desc.Type()
              << "'s use_onednn == True, but PIR not support OneDNN for this "
                 "op right now.";
      return kTargetDialectPrefix;
    } else {
      return kOneDNNTargetDialectPrefix;
    }
  } else {
    return kTargetDialectPrefix;
  }
#else
  return kTargetDialectPrefix;
#endif
}
}  // namespace

// Resolves the PIR OpInfo for a legacy OpDesc. Three cases:
//   1. ops without a multi-kernel mapping: direct registry lookup;
//   2. sparse multi-kernel ops: pick the signature whose input types
//      (dense / sparse_coo / sparse_csr) match, then append "_sp";
//   3. other multi-kernel ops (e.g. selected_rows variants): pick the
//      signature whose input types (dense / selected_rows) match.
// BUG FIX: in case 3 the original fetched `op_info` for the *pre-remap*
// target name and returned that stale handle after remapping; we now
// re-fetch the OpInfo for the remapped name (mirroring the sparse path).
pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx,
                                        const OpDesc& op_desc) {
  std::string target_op_name =
      GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type());
  // Inplace op variants are registered with a trailing underscore in PIR.
  if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') {
    target_op_name += "_";
  }
  VLOG(6) << "[op name normalizing]: " << op_desc.Type() << " to "
          << target_op_name;
  // Case 1: single-kernel op -- plain registry lookup.
  if (!paddle::dialect::HaveOpToMultiKernelsMap(
          OpNameCompatibleMapping(op_desc.Type()))) {
    auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      IR_THROW("Op %d should have corresponding OpInfo %d",
               op_desc.Type(),
               target_op_name);
    }
    return op_info;
  }
  // Case 2: sparse multi-kernel op -- match input types against the
  // registered sparse signatures.
  if (paddle::dialect::HaveOpToMultiKernelsMap(
          OpNameCompatibleMapping(op_desc.Type())) &&
      isSparseString(op_desc.Type())) {
    std::map<std::string, std::vector<std::string>> inputs = op_desc.Inputs();
    std::vector<std::string> input_types;
    for (const auto& pair : inputs) {
      // These ops dispatch on their "x" input only.
      if (op_desc.Type() == "sparse_sum" || op_desc.Type() == "sparse_slice" ||
          op_desc.Type() == "sparse_reshape") {
        if (pair.first != "x") {
          continue;
        }
      }
      VarDesc* var_desc = op_desc.Block()->FindVarRecursive(pair.second[0]);
      PADDLE_ENFORCE_NE(
          var_desc,
          nullptr,
          common::errors::InvalidArgument("[Op:%s] Input %s should not be null",
                                          op_desc.Type(),
                                          pair.second[0]));
      if (var_desc->GetType() ==
          paddle::framework::proto::VarType::SPARSE_COO) {
        input_types.emplace_back("sparse_coo");
      } else if (var_desc->GetType() ==
                 paddle::framework::proto::VarType::SPARSE_CSR) {
        input_types.emplace_back("sparse_csr");
      } else if (var_desc->GetType() ==
                 paddle::framework::proto::VarType::DENSE_TENSOR) {
        input_types.emplace_back("dense");
      } else {
        PADDLE_THROW(common::errors::InvalidArgument(
            "Op %d only support dense tensor ,sparse_coo and sparse_csr, but "
            "not %d",
            op_desc.Type(),
            var_desc->GetType()));
      }
    }
    target_op_name = OpNameCompatibleMapping(op_desc.Type());
    auto sig_infos = paddle::dialect::SparseOpToPdOpsMapping(target_op_name);

    // Pick the first signature whose per-input types all match; an empty
    // entry in input_types acts as a wildcard.
    target_op_name = "";
    for (const auto& sig : sig_infos) {
      if (input_types.size() != sig.inputs.size()) {
        continue;
      }
      size_t i = 0;
      for (i = 0; i < input_types.size(); ++i) {
        if (input_types[i] == "") {
          continue;
        }
        if (input_types[i] != sig.inputs[i]) {
          break;
        }
      }
      if (i == input_types.size()) {
        target_op_name = sig.name;
        break;
      }
    }
    PADDLE_ENFORCE_EQ(!target_op_name.empty(),
                      true,
                      common::errors::InvalidArgument(
                          "Op %d should have corresponding OpInfo %d",
                          op_desc.Type(),
                          target_op_name));

    // Sparse kernels are registered with a "_sp" suffix.
    target_op_name = GetPrefix(ctx, op_desc) + target_op_name + "_sp";
    if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') {
      target_op_name += "_";
    }
    auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op %d should have corresponding OpInfo %d",
          op_desc.Type(),
          target_op_name));
    }

    return op_info;
  }

  // Case 3: dense/selected_rows multi-kernel op. Use the YAML input info of
  // the default kernel to build this op's input-type signature.
  auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
  auto* op_info_concept =
      op_info.GetInterfaceImpl<dialect::OpYamlInfoInterface>();

  OpInputInfoList input_infos;
  OpAttributeInfoList attr_infos;
  OpOutputInfoList output_infos;
  std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
      op_info_concept->get_op_info_(op_info.name());

  auto& op_normalizer = OpNameNormalizer::instance();
  std::vector<std::string> need_inputs_sig;
  for (const auto& info : input_infos) {
    // Mutable attributes are not real tensor inputs; skip them.
    if (info.is_mutable_attribute) {
      continue;
    }
    std::string legacy_input_name =
        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
    auto legacy_input_vars = op_desc.Input(legacy_input_name, true);
    PADDLE_ENFORCE_EQ(legacy_input_vars.size() <= 1,
                      true,
                      common::errors::InvalidArgument(
                          "Do not support duplicable tensor input, "
                          "when op have multi kernels. OP is %s.",
                          op_desc.Type()));

    // Absent optional input: record a wildcard slot.
    if (legacy_input_vars.empty()) {
      need_inputs_sig.emplace_back("");
      continue;
    }
    VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]);
    PADDLE_ENFORCE_NE(
        var,
        nullptr,
        common::errors::InvalidArgument("[Op:%s] Input %s should not be null",
                                        op_desc.Type(),
                                        legacy_input_vars[0]));

    if (var->GetType() == paddle::framework::proto::VarType::DENSE_TENSOR) {
      need_inputs_sig.emplace_back("dense");
    } else if (var->GetType() ==
               paddle::framework::proto::VarType::SELECTED_ROWS) {
      need_inputs_sig.emplace_back("selected_rows");
    } else {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op %d only support dense tensor and selected_rows, but not %d",
          op_desc.Type(),
          var->GetType()));
    }
  }

  target_op_name = OpNameCompatibleMapping(op_desc.Type());

  auto sig_infos = paddle::dialect::LegacyOpToPdOpsMapping(target_op_name);

  // Pick the first signature whose per-input types all match; an empty entry
  // in need_inputs_sig acts as a wildcard.
  target_op_name = "";
  for (const auto& sig : sig_infos) {
    if (need_inputs_sig.size() != sig.inputs.size()) {
      continue;
    }
    size_t i = 0;
    for (i = 0; i < need_inputs_sig.size(); ++i) {
      if (need_inputs_sig[i] == "") {
        continue;
      }
      if (need_inputs_sig[i] != sig.inputs[i]) {
        break;
      }
    }
    if (i == need_inputs_sig.size()) {
      target_op_name = sig.name;
      break;
    }
  }

  PADDLE_ENFORCE_EQ(!target_op_name.empty(),
                    true,
                    common::errors::InvalidArgument(
                        "Op %d should have corresponding OpInfo %d",
                        op_desc.Type(),
                        target_op_name));

  target_op_name = GetPrefix(ctx, op_desc) + target_op_name;
  if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') {
    target_op_name += "_";
  }
  // Re-fetch the OpInfo for the remapped kernel name; the lookup above used
  // the pre-remapping name, so returning it directly would yield the wrong
  // (stale) OpInfo. This mirrors the sparse branch above.
  op_info = ctx->GetRegisteredOpInfo(target_op_name);
  if (!op_info) {
    PADDLE_THROW(common::errors::InvalidArgument(
        "Op %d should have corresponding OpInfo %d",
        op_desc.Type(),
        target_op_name));
  }

  return op_info;
}

// Inserts SliceOps for any of this op's input variables that were produced
// as elements of a vector<Tensor> (and are not consumed directly through the
// op's YAML-declared inputs), so downstream code can use them as plain
// tensors.
void OpTranscriber::InsertSliceOperationForInput(
    pir::IrContext* ctx,
    TranslationContext* param_map,
    const OpDesc& op_desc,
    const OpInputInfoList& input_infos,
    pir::Block* block) {
  // Collect the legacy names of every input declared in the op's YAML info.
  auto& op_normalizer = OpNameNormalizer::instance();
  std::set<std::string> declared_inputs;
  for (const auto& info : input_infos) {
    declared_inputs.insert(
        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name));
  }

  for (const auto& [slot_name, args] : op_desc.Inputs()) {
    (void)slot_name;
    for (const auto& arg_name : args) {
      if (param_map->count(arg_name) == 0 ||
          declared_inputs.count(arg_name) != 0) {
        continue;
      }
      auto defining_info = param_map->at(arg_name);
      if (!defining_info.generated_by_vector) {
        continue;
      }
      InsertSliceOperationForTarget(
          ctx, param_map, block, defining_info, arg_name);
      VLOG(8) << "[op:" << op_desc.Type()
              << "] insert slice for var: " << arg_name;
    }
  }
}

// Materializes a mutable attribute (declared as a tensor input in PIR but
// stored as an attribute in the legacy OpDesc) as the result of a freshly
// inserted full / full_int_array op. Throws if the legacy attribute is
// missing from the OpDesc.
pir::Value OpTranscriber::GetAttributeAsInput(pir::IrContext* ctx,
                                              pir::Block* block,
                                              const OpDesc& op_desc,
                                              const OpInputInfo& input_info) {
  auto& op_normalizer = OpNameNormalizer::instance();
  const std::string legacy_attr_name =
      op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);

  if (!op_desc.HasAttr(legacy_attr_name)) {
    PADDLE_THROW(
        common::errors::InvalidArgument("Op %s arg %s should not be zero size",
                                        op_desc.Type(),
                                        legacy_attr_name));
  }
  paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
  VLOG(10) << "[" << op_desc.Type() << "][attribute]"
           << " name: " << legacy_attr_name << " " << legacy_attr.index();

  auto& attribute_translator = AttributeTranslator::instance();
  pir::Attribute new_attr =
      attribute_translator(input_info.type_name, legacy_attr);

  // IntArray attributes become full_int_array ops; scalars become full ops.
  const bool is_int_array =
      input_info.type_name.find("IntArrayAttribute") != std::string::npos;
  pir::Operation* defining_op =
      is_int_array
          ? InsertFullArrayOperationForAttributeInput(ctx, block, new_attr)
          : InsertFullOperationForAttributeInput(ctx, block, new_attr);

  return defining_op->result(0);
}

// Translates a legacy OpDesc's inputs into the ordered list of pir::Values
// the new operation will consume. Per input (in YAML declaration order):
//   - a registered special handler, if any, fully owns the translation;
//   - an absent optional input becomes a null Value placeholder;
//   - an absent mutable attribute is materialized via GetAttributeAsInput;
//   - vector inputs are assembled with a CombineOp; plain tensors are looked
//     up directly in param_map.
std::vector<pir::Value> OpTranscriber::GenerateOperationInput(
    pir::IrContext* ctx,
    TranslationContext* param_map,
    const OpDesc& op_desc,
    const std::string& normalized_op_name,
    const OpInputInfoList& input_infos,
    pir::Block* block) {
  VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance";

  auto& op_normalizer = OpNameNormalizer::instance();
  const auto* mutable_attributes =
      op_normalizer.GetMutableAttributes(op_desc.Type());

  VLOG(10) << "[op:" << op_desc.Type() << "][input] start";

  std::vector<pir::Value> op_inputs;

  for (const auto& info : input_infos) {
    // Subclasses may register a handler that fully owns this input's
    // translation.
    if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
      pir::Value ret = special_handler(
          ctx, param_map, op_desc, normalized_op_name, info, block);
      op_inputs.push_back(ret);
      continue;
    }

    std::string legacy_input_name =
        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);

    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
             << legacy_input_name;

    std::vector<std::string> legacy_input_vars;
    // return empty Value if this arg is optional and not shown in OpDesc
    if (op_desc.HasInput(legacy_input_name, true)) {
      legacy_input_vars = op_desc.Input(legacy_input_name, true);
    }

    if (legacy_input_vars.empty()) {
      if (info.optional) {
        op_inputs.emplace_back(nullptr);
        continue;
      }
    }
    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
             << legacy_input_name << " " << legacy_input_vars.size() << "["
             << legacy_input_vars << "]";

    // A missing non-optional input that is a mutable attribute: look for the
    // attribute's candidate variable names, and if none exists either,
    // materialize the attribute value as a full/full_int_array op.
    if (legacy_input_vars.empty() && mutable_attributes != nullptr &&
        mutable_attributes->count(info.name) != 0) {
      const auto& candidate_var_names =
          op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name);
      bool found_candidate_var = false;
      for (const auto& var_name : candidate_var_names) {
        VLOG(10) << "[handle mutable attribute][" << info.name << "]["
                 << var_name << "]";
        if (op_desc.HasInput(var_name)) {
          legacy_input_vars = op_desc.Input(var_name, true);
          if (legacy_input_vars.empty()) continue;
          found_candidate_var = true;
          break;
        }
      }

      if (!found_candidate_var) {
        auto attribute_input = GetAttributeAsInput(ctx, block, op_desc, info);
        op_inputs.push_back(attribute_input);
        continue;
      }
    }

    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
    is_vector |=
        (info.type_name.find("IntArrayAttribute") != std::string::npos);
    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
             << is_vector << " " << info.type_name;
    // Specially process TensorArray, this because we cannot distinguish it with
    // Vector<DenseTensor> by other conditions but we cannot support it like
    // Vector<DenseTensor>
    if (legacy_input_vars.size() == 1) {
      VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]);
      PADDLE_ENFORCE_NE(
          var,
          nullptr,
          common::errors::InvalidArgument("[op:%s] Input %s should not be null",
                                          op_desc.Type(),
                                          legacy_input_vars[0]));
      if (var->GetType() ==
          paddle::framework::proto::VarType::DENSE_TENSOR_ARRAY) {
        is_vector = false;
      }
    }

    // if src type is Tensor
    if (!is_vector) {
      PADDLE_ENFORCE_EQ(legacy_input_vars.size(),
                        1UL,
                        common::errors::InvalidArgument(
                            "Input %s not found when parsing op %s",
                            info.name,
                            op_desc.Type()));
      PADDLE_ENFORCE_NE(param_map->count(legacy_input_vars[0]),
                        0UL,
                        common::errors::InvalidArgument(
                            "Input [%s: %s] of op [%s] not found in param map",
                            info.name,
                            legacy_input_vars[0],
                            op_desc.Type()));
      auto defining_info = (*param_map)[legacy_input_vars[0]];
      op_inputs.push_back(defining_info.value);

      // if src type is Vector<Tensor> , need an additional `CombineOp` to
      // assemble them.
    } else {
      auto* combine_op = InsertCombineOperationForTarget(
          ctx, param_map, block, legacy_input_vars);
      op_inputs.push_back(combine_op->result(0));
    }
  }

  return op_inputs;
}

// Builds the result-type list for the new operation from the legacy OpDesc's
// outputs, and a mapping from each legacy output variable name to its
// (result index, index inside a VectorType) position. Optional/absent
// outputs become null types; TensorArray outputs are translated directly;
// vector outputs become a single VectorType result.
std::tuple<OpOutputTypeList, OpOutputMapping>
OpTranscriber::GenerateOperationOutput(pir::IrContext* ctx,
                                       const OpDesc& op_desc,
                                       const OpOutputInfoList& output_infos) {
  OpOutputMapping arg_to_idx;
  OpOutputTypeList op_output_types = {};

  auto& type_translator = TypeTranslator::instance();
  auto& op_normalizer = OpNameNormalizer::instance();

  const BlockDesc* block = op_desc.Block();

  for (const auto& info : output_infos) {
    size_t cur_output_idx = op_output_types.size();
    std::string legacy_output_name =
        op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);

    VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
             << legacy_output_name;

    // return empty type if this arg is optional and not shown in OpDesc
    if (!op_desc.HasOutput(legacy_output_name)) {
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "] optional " << info.name << " :"
               << info.type_name << " " << legacy_output_name;
      PADDLE_ENFORCE_EQ(
          info.optional,
          true,
          common::errors::InvalidArgument(
              "Op %s arg %s should be optional if it can be empty",
              op_desc.Type(),
              legacy_output_name));
      op_output_types.emplace_back(nullptr);
      continue;
    }

    const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
    bool is_vector = (info.type_name.find("VectorType") != std::string::npos);

    VLOG(10) << "[op:" << op_desc.Type() << "][output]" << info.name << " "
             << legacy_output_name << " " << legacy_output_vars.size() << " "
             << is_vector;

    // Specially process TensorArray, this because we cannot distinguish it with
    // Vector<DenseTensor> by other conditions but we cannot support it like
    // Vector<DenseTensor>
    if (legacy_output_vars.size() == 1) {
      VarDesc* var = block->FindVarRecursive(legacy_output_vars[0]);
      PADDLE_ENFORCE_NE(var,
                        nullptr,
                        common::errors::InvalidArgument(
                            "[op:%s] Output %s should not be null",
                            op_desc.Type(),
                            legacy_output_vars[0]));
      if (var->GetType() ==
          paddle::framework::proto::VarType::DENSE_TENSOR_ARRAY) {
        pir::Type translated_var_type =
            type_translator[var->GetType()](ctx, *var);
        op_output_types.push_back(translated_var_type);
        arg_to_idx[var->Name()] = {cur_output_idx, 0};
        continue;
      }
    }

    // if src type is Tensor
    if (!is_vector) {
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "]" << info.name << " :"
               << info.type_name << " " << legacy_output_name << " "
               << legacy_output_vars.size();
      // Declared output with no bound variable: emit a null result type.
      if (legacy_output_vars.empty()) {
        op_output_types.emplace_back(nullptr);
        continue;
      }

      auto& var_name = legacy_output_vars[0];
      VarDesc* var = block->FindVarRecursive(var_name);
      PADDLE_ENFORCE_NE(var,
                        nullptr,
                        common::errors::InvalidArgument(
                            "[op:%s] Output %s should not be null",
                            op_desc.Type(),
                            var_name));
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "]" << info.name
               << " var: " << var_name << " type: " << var->GetType();

      pir::Type translated_var_type =
          type_translator[var->GetType()](ctx, *var);

      arg_to_idx[var_name] = {cur_output_idx, 0};
      op_output_types.push_back(translated_var_type);

      // if src type is Vector<Tensor>
    } else {
      VLOG(10) << "[output translating]"
               << "[" << op_desc.Type() << "]" << info.name << " :"
               << info.type_name << " var: " << legacy_output_name;
      std::vector<pir::Type> types;
      for (IdxInVector idx_in_vec = 0; idx_in_vec < legacy_output_vars.size();
           idx_in_vec++) {
        const auto& var_name = legacy_output_vars[idx_in_vec];
        // "@EMPTY@" placeholders keep their slot but carry a null type.
        if (var_name == kEmptyVarName) {
          types.emplace_back(nullptr);
          arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
          continue;
        }
        VarDesc* var = block->FindVarRecursive(var_name);
        PADDLE_ENFORCE_NE(var,
                          nullptr,
                          common::errors::InvalidArgument(
                              "[op:%s] Output %s should not be null",
                              op_desc.Type(),
                              var_name));
        VLOG(10) << "[output translating]"
                 << "[" << op_desc.Type() << "]" << info.name
                 << " var: " << var_name << " type: " << var->GetType();
        pir::Type translated_var_type =
            type_translator[var->GetType()](ctx, *var);
        types.push_back(translated_var_type);
        arg_to_idx[var_name] = {cur_output_idx, idx_in_vec};
      }
      pir::Type vec_type = pir::VectorType::get(ctx, types);
      op_output_types.push_back(vec_type);
    }
  }
  return {op_output_types, arg_to_idx};
}

// Copies the legacy operator's distributed attributes (if any) into the PIR
// attribute map. Stream/scheduling fields are recorded only when they deviate
// from their defaults; the event-related fields are always recorded.
static void TranslateOpDistAttribute(const OpDesc& op_desc,
                                     pir::AttributeMap* attr_map) {
  const paddle::framework::OperatorDistAttr* dist_attr = op_desc.DistAttr();
  if (dist_attr == nullptr) {
    return;
  }
  auto& attribute_translator = AttributeTranslator::instance();

  if (dist_attr->execution_stream() !=
      paddle::distributed::auto_parallel::kDefault) {
    (*attr_map)["execution_stream"] = attribute_translator(
        "execution_stream", dist_attr->execution_stream());
  }
  if (dist_attr->stream_priority() != 0) {
    (*attr_map)["stream_priority"] =
        attribute_translator("stream_priority", dist_attr->stream_priority());
  }
  if (dist_attr->scheduling_priority() != 0) {
    (*attr_map)["scheduling_priority"] = attribute_translator(
        "scheduling_priority", dist_attr->scheduling_priority());
  }

  (*attr_map)["force_record_event"] = attribute_translator(
      "force_record_event", dist_attr->force_record_event());
  (*attr_map)["event_to_record"] =
      attribute_translator("event_to_record", dist_attr->event_to_record());
  (*attr_map)["events_to_wait"] =
      attribute_translator("events_to_wait", dist_attr->events_to_wait());
}

// Translates the legacy OpDesc's attributes into a PIR AttributeMap,
// preferring op-specific special handlers, then generic translation of the
// legacy attribute, and finally the nonexistent-attribute fallback.
pir::AttributeMap OpTranscriber::TranslateOpAttribute(
    pir::IrContext* ctx,
    const std::string& normalized_op_name,
    const OpAttributeInfoList& op_attr_infos,
    const OpDesc& op_desc) {
  auto& attribute_translator = AttributeTranslator::instance();
  auto& op_normalizer = OpNameNormalizer::instance();
  pir::AttributeMap attribute_map;

  for (const auto& info : op_attr_infos) {
    // Op-specific handlers take priority over generic translation.
    auto special_handler = this->GetSpecialAttributeHandlers(info.name);
    if (special_handler) {
      attribute_map[info.name] = special_handler(ctx, op_desc, info);
      continue;
    }

    const auto legacy_attr_name =
        op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
    VLOG(10) << "[op: " << op_desc.Type()
             << "][attr] from: " << legacy_attr_name << " to: " << info.name;

    if (!op_desc.HasAttr(legacy_attr_name)) {
      VLOG(10) << "attribute in " << op_desc.Type()
               << " name: " << legacy_attr_name << " doesn't exist";
      this->HandleNonexistentAttribute(ctx, &attribute_map, info);
      continue;
    }

    paddle::framework::Attribute legacy_attr =
        op_desc.GetAttr(legacy_attr_name);
    VLOG(10) << "attribute in " << op_desc.Type()
             << " name: " << legacy_attr_name << " " << legacy_attr.index();
    pir::Attribute new_attr = attribute_translator(info.type_name, legacy_attr);
    attribute_map[info.name] = new_attr;
    if (!new_attr) {
      VLOG(0) << "empty attribute in " << op_desc.Type()
              << " name: " << info.name;
    }
  }

  return attribute_map;
}

// Fallback for attributes required by the new IR op definition but missing
// from the legacy OpDesc: translate a default-constructed legacy attribute
// of the expected type.
void OpTranscriber::HandleNonexistentAttribute(pir::IrContext*,
                                               pir::AttributeMap* attribute_map,
                                               const OpAttributeInfo& info) {
  auto& translator = AttributeTranslator::instance();
  pir::Attribute default_attr =
      translator(info.type_name, paddle::framework::Attribute());
  (*attribute_map)[info.name] = default_attr;
}

// Registers each result of `operation` in `param_map` under its legacy
// variable name so later ops can resolve their inputs against it.
void OpTranscriber::RecordOpResultMapping(pir::IrContext* ctx,
                                          TranslationContext* param_map,
                                          const OpDesc& op_desc,
                                          pir::Operation* operation,
                                          const OpOutputMapping& arg_to_idx) {
  for (const auto& [var_name, position] : arg_to_idx) {
    const auto& [result_idx, slot_idx] = position;
    VLOG(10) << "[output recording]"
             << "[" << op_desc.Type() << "]" << var_name << " " << result_idx
             << " " << slot_idx;
    pir::Value result = operation->result(result_idx);
    // A VectorType result means the legacy variable is packed inside a
    // vector value at `slot_idx`; otherwise the slot index is irrelevant.
    const bool is_vector = result.type().isa<pir::VectorType>();
    const int vec_slot = is_vector ? static_cast<int>(slot_idx) : -1;

    param_map->PushValue(var_name,
                         VariableDefiningInfo(result, is_vector, vec_slot));
  }
}

// Generic transcription pipeline: look up the target op info, translate
// inputs, outputs and attributes from the legacy OpDesc, create the new
// operation in `block`, then record its results in `param_map`.
pir::Operation* OpTranscriber::operator()(pir::IrContext* ctx,
                                          TranslationContext* param_map,
                                          const OpDesc& op_desc,
                                          pir::Block* block) {
  pir::OpInfo op_info = this->LookUpOpInfo(ctx, op_desc);
  auto* yaml_info_interface =
      op_info.GetInterfaceImpl<dialect::OpYamlInfoInterface>();

  OpInputInfoList input_infos;
  OpAttributeInfoList attr_infos;
  OpOutputInfoList output_infos;
  std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
      yaml_info_interface->get_op_info_(op_info.name());

  this->InsertSliceOperationForInput(
      ctx, param_map, op_desc, input_infos, block);

  std::vector<pir::Value> op_inputs = this->GenerateOperationInput(
      ctx, param_map, op_desc, op_info.name(), input_infos, block);

  OpOutputMapping arg_to_idx;
  OpOutputTypeList op_output_types;
  std::tie(op_output_types, arg_to_idx) =
      this->GenerateOperationOutput(ctx, op_desc, output_infos);

  pir::AttributeMap attribute_map =
      this->TranslateOpAttribute(ctx, op_info.name(), attr_infos, op_desc);
  TranslateOpDistAttribute(op_desc, &attribute_map);
  VLOG(4) << "[general op][" << op_desc.Type() << "] preparation end.";

  pir::Operation* new_op = pir::Operation::Create(
      op_inputs, attribute_map, op_output_types, op_info);
  VLOG(4) << "[general op][" << op_desc.Type() << "] operation creation end.";
  block->push_back(new_op);
  VLOG(4) << "[general op][" << op_desc.Type() << "] operation insertion end.";

  this->RecordOpResultMapping(ctx, param_map, op_desc, new_op, arg_to_idx);
  return new_op;
}

// argsort: legacy programs may predate the `stable` attribute; default it
// to false when absent.
struct ArgsortOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "stable") {
      return;
    }
    (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
  }
};

// `assign` on a DENSE_TENSOR_ARRAY input maps to AssignArray_; any other
// input type falls back to the generic name-mapping lookup.
struct Assign2AssignOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    PADDLE_ENFORCE_EQ(op_desc.HasInput("X"),
                      true,
                      common::errors::InvalidArgument(
                          "op %s should have input `X`", op_desc.Type()));
    const auto& input_vars = op_desc.Input("X");
    PADDLE_ENFORCE_EQ(input_vars.size() == 1,
                      true,
                      common::errors::InvalidArgument(
                          "op %s should have one input `X`, but got %d.",
                          op_desc.Type(),
                          input_vars.size()));

    const auto* input_var = op_desc.Block()->FindVarRecursive(input_vars[0]);
    if (input_var->GetType() != framework::proto::VarType::DENSE_TENSOR_ARRAY) {
      // Non-array assigns use the default lookup path.
      return OpTranscriber::LookUpOpInfo(ctx, op_desc);
    }

    const std::string target_op_name = dialect::AssignArray_Op::name();
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op assign should have corresponding OpInfo %s.", target_op_name));
    }
    return op_info;
  }
};

struct Assign2AssignOutOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    const auto& op_info =
        ctx->GetRegisteredOpInfo(paddle::dialect::AssignOut_Op::name());
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op assign should have corresponding OpInfo %s.",
          paddle::dialect::AssignOut_Op::name()));
    }

    return op_info;
  }

  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    std::vector<pir::Value> op_inputs;
    auto x_vars = op_desc.Input("X", true);
    auto x_defining_info = (*param_map)[x_vars[0]];
    op_inputs.push_back(x_defining_info.value);

    auto out_vars = op_desc.Output("Out");
    auto out_defining_info = (*param_map)[out_vars[0]];
    op_inputs.push_back(out_defining_info.value);

    return op_inputs;
  }
};

// Dispatches `assign` to one of two transcribers depending on whether the
// output variable was already defined earlier in the program.
struct AssignOpTranscriber : public OpTranscriber {
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    const bool out_already_defined =
        param_map->count(op_desc.Output("Out")[0]) != 0;
    if (out_already_defined) {
      // Re-assignment into an existing value.
      return Assign2AssignOutOpTranscriber()(ctx, param_map, op_desc, block);
    }
    return Assign2AssignOpTranscriber()(ctx, param_map, op_desc, block);
  }
};

// batch_norm: `use_global_stats` and `trainable_statistics` both default to
// false when the legacy program omits them.
struct BatchNormOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "use_global_stats" ||
        info.name == "trainable_statistics") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    }
  }
};

// `cast` stores its target dtype under the legacy attribute `out_dtype`;
// translate only that attribute (plus oneDNN data-type hints when built
// with DNNL support).
struct CastOpTranscriber : public OpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& translator = AttributeTranslator::instance();
    pir::AttributeMap attribute_map;
    // The op definition lists a single attribute (the output dtype).
    const OpAttributeInfo& dtype_info = op_attr_infos[0];

    const std::string legacy_attr_name("out_dtype");
    paddle::framework::Attribute legacy_attr;
    if (op_desc.HasAttr(legacy_attr_name)) {
      legacy_attr = op_desc.GetAttr(legacy_attr_name);
    }
    VLOG(10) << "attribute in " << op_desc.Type()
             << " name: " << legacy_attr_name << " " << legacy_attr.index();
    attribute_map[dtype_info.name] =
        translator(dtype_info.type_name, legacy_attr);

#ifdef PADDLE_WITH_DNNL
    for (const std::string dnnl_attr :
         {"mkldnn_data_type", "onednn_data_type"}) {
      if (op_desc.HasAttr(dnnl_attr)) {  // NOLINT
        attribute_map[dnnl_attr] = pir::StrAttribute::get(
            ctx, op_desc.GetAttrIfExists<std::string>(dnnl_attr));
      }
    }
#endif
    return attribute_map;
  }
};

// leaky_relu: like the generic path, except the `alpha` attribute must be
// widened from float to double for the new IR op definition.
struct LeakyReLUOpTranscriber : public OpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& translator = AttributeTranslator::instance();
    auto& normalizer = OpNameNormalizer::instance();
    pir::AttributeMap attribute_map;

    for (const auto& attr_info : op_attr_infos) {
      const auto legacy_name =
          normalizer.GetLegacyAttrName(op_desc.Type(), attr_info.name);
      VLOG(10) << "[op: " << op_desc.Type()
               << "][attr] from: " << legacy_name << " to: " << attr_info.name;
      if (!op_desc.HasAttr(legacy_name)) {
        continue;
      }
      paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_name);
      VLOG(10) << "attribute in " << op_desc.Type()
               << " name: " << legacy_name << " " << legacy_attr.index();
      pir::Attribute translated = translator(attr_info.type_name, legacy_attr);
      if (legacy_name == "alpha") {
        // The legacy attribute is translated as a float; widen to double.
        const double widened = static_cast<double>(
            translated.dyn_cast<pir::FloatAttribute>().data());
        translated = pir::DoubleAttribute::get(ctx, widened);
      }
      attribute_map[attr_info.name] = translated;
    }

    return attribute_map;
  }
};

// interpolate family: the legacy `scale` attribute arrives as a float
// array, while the new IR expects doubles; convert element-wise.
struct InterpolateOpTranscriber : public OpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& translator = AttributeTranslator::instance();
    auto& normalizer = OpNameNormalizer::instance();
    pir::AttributeMap attribute_map;

    for (const auto& attr_info : op_attr_infos) {
      const auto legacy_name =
          normalizer.GetLegacyAttrName(op_desc.Type(), attr_info.name);
      VLOG(10) << "[op: " << op_desc.Type()
               << "][attr] from: " << legacy_name << " to: " << attr_info.name;
      if (!op_desc.HasAttr(legacy_name)) {
        continue;
      }
      paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_name);
      VLOG(10) << "attribute in " << op_desc.Type()
               << " name: " << legacy_name << " " << legacy_attr.index();
      pir::Attribute translated = translator(attr_info.type_name, legacy_attr);
      if (legacy_name == "scale" && translated.isa<pir::ArrayAttribute>()) {
        translated = WidenFloatArray(ctx, translated);
      }
      attribute_map[attr_info.name] = translated;
    }

    return attribute_map;
  }

 private:
  // Rebuilds a float ArrayAttribute as a double ArrayAttribute; arrays that
  // are empty or hold non-float elements are returned unchanged.
  static pir::Attribute WidenFloatArray(pir::IrContext* ctx,
                                        pir::Attribute attr) {
    auto elements = attr.dyn_cast<pir::ArrayAttribute>().AsVector();
    if (elements.empty() || !elements[0].isa<pir::FloatAttribute>()) {
      return attr;
    }
    std::vector<pir::Attribute> widened;
    widened.reserve(elements.size());
    for (const auto& element : elements) {
      widened.push_back(pir::DoubleAttribute::get(
          ctx,
          static_cast<double>(element.dyn_cast<pir::FloatAttribute>().data())));
    }
    return pir::ArrayAttribute::get(ctx, widened);
  }
};

// conv2d: default the missing `padding_algorithm` attribute to "EXPLICIT".
struct Conv2dOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "padding_algorithm") {
      return;
    }
    (*attribute_map)[info.name] = pir::StrAttribute::get(ctx, "EXPLICIT");
  }
};

// conv3d: default the missing `padding_algorithm` attribute to "EXPLICIT".
struct Conv3dOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "padding_algorithm") {
      auto explicit_padding = pir::StrAttribute::get(ctx, "EXPLICIT");
      (*attribute_map)[info.name] = explicit_padding;
    }
  }
};

// scale: missing attributes default to bias = 0.0 and
// bias_after_scale = true.
struct ScaleOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "bias_after_scale") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, true);
    } else if (info.name == "bias") {
      (*attribute_map)[info.name] =
          paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(0.0));
    }
  }
};

// dropout: default the missing `mode` attribute to "downscale_in_infer".
struct DropoutOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "mode") {
      return;
    }
    (*attribute_map)[info.name] =
        pir::StrAttribute::get(ctx, "downscale_in_infer");
  }
};

// sequence_pool: default the missing `pad_value` attribute to 0.0f.
struct SequencePoolOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "pad_value") {
      return;
    }
    (*attribute_map)[info.name] = pir::FloatAttribute::get(ctx, 0.0f);
  }
};

// (shape, tensor type, defining value) of a single dense-tensor variable.
using ValueInfo =
    std::tuple<std::vector<int64_t>, dialect::DenseTensorType, pir::Value>;

// Resolves the single variable behind `names` (an input/output slot of
// `op_desc`) into its shape, DenseTensorType and pir::Value, enforcing that
// the slot holds exactly one already-translated dense tensor.
ValueInfo GetTensorInfoByVarName(const OpDesc& op_desc,
                                 const std::vector<std::string>& names,
                                 TranslationContext* param_map,
                                 const std::string& var_name) {
  PADDLE_ENFORCE_EQ(
      names.size(),
      1UL,
      common::errors::InvalidArgument(
          "Expected op[%s]'s input %s has only 1 variable, but got %d",
          op_desc.Type(),
          var_name,
          names.size()));
  const auto& name = names[0];
  PADDLE_ENFORCE_GT(
      param_map->count(name),
      0UL,
      common::errors::InvalidArgument(
          "Expected op[%s]'s input %s has been parsed", op_desc.Type(), name));

  pir::Value value = param_map->at(name).value;
  PADDLE_ENFORCE_NE(
      value,
      nullptr,
      common::errors::PreconditionNotMet(
          "Expected op[%s]'s input %s is not null", op_desc.Type(), name));

  const pir::Type& type = value.type();
  PADDLE_ENFORCE_EQ(type.isa<dialect::DenseTensorType>(),
                    true,
                    common::errors::InvalidArgument(
                        "Expected op[%s]'s input %s is DenseTensor but got %s",
                        op_desc.Type(),
                        name,
                        type));
  auto tensor_type = type.dyn_cast<dialect::DenseTensorType>();

  return std::make_tuple(
      common::vectorize(tensor_type.dims()), tensor_type, value);
}

// lookup_table / lookup_table_v2 -> pd_op.embedding. For the legacy
// `lookup_table` variant, a squeeze on axis -2 is appended to the output and
// the legacy output name is re-bound to the squeezed value (presumably to
// drop the extra singleton axis the v1 op carries — confirm against the
// legacy op definition).
struct EmbeddingOpTranscriber : public OpTranscriber {
  void RecordOpResultMapping(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Operation* operation,
                             const OpOutputMapping& arg_to_idx) override {
    OpTranscriber::RecordOpResultMapping(
        ctx, param_map, op_desc, operation, arg_to_idx);
    if (op_desc.Type() == "lookup_table") {
      ValueInfo out_info = GetTensorInfoByVarName(
          op_desc, op_desc.Output("Out"), param_map, "Out");
      const auto& output_vars = op_desc.Output("Out");
      const auto& output_name = output_vars[0];

      // Squeeze axis -2 of the embedding output and make the legacy output
      // name resolve to the squeezed value from here on.
      pir::Value& out_value = std::get<2>(out_info);
      pir::Builder builder(ctx, operation->GetParent());
      std::vector<int64_t> axis = {-2};
      dialect::SqueezeOp squeeze_op_out =
          builder.Build<dialect::SqueezeOp>(out_value, axis);
      pir::Value out_new = squeeze_op_out.out();
      param_map->PushValue(output_name,
                           VariableDefiningInfo(out_new, false, -1));
    }
  }

  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    auto op_info = ctx->GetRegisteredOpInfo("pd_op.embedding");
    if (!op_info) {
      // Fixed format specifiers: both arguments are strings, so %s (the
      // previous %d mismatched the argument types).
      IR_THROW("Op %s should have corresponding OpInfo %s",
               op_desc.Type(),
               "pd_op.embedding");
    }
    return op_info;
  }

  // Defaults when the legacy program omits these attributes.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "padding_idx") {
      (*attribute_map)[info.name] = pir::Int64Attribute::get(ctx, -1);
    } else if (info.name == "sparse") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    }
  }
};

// increment: the legacy `step` attribute maps to `value`; when absent the
// default is 1.0f.
struct IncrementOpTranscriber : public OpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& translator = AttributeTranslator::instance();
    pir::AttributeMap attribute_map;

    if (!op_desc.HasAttr("step")) {
      attribute_map["value"] = pir::FloatAttribute::get(ctx, 1.0f);
      return attribute_map;
    }

    paddle::framework::Attribute legacy_attr = op_desc.GetAttr("step");
    VLOG(10) << "attribute in " << op_desc.Type() << " step: "
             << " " << legacy_attr.index();
    attribute_map["value"] = translator(legacy_attr);
    return attribute_map;
  }
};

// The `assign_value` in static_ops.yaml is different from the one in
// `dygraph_ops.yaml`. For this op we simulate the logic in
// python/paddle/tensor/creation.py::assign(x, output)
struct AssignValueOpTranscriber : public OpTranscriber {
  // Always targets pd_op.assign_value, regardless of the legacy op name.
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name = "pd_op.assign_value";
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_ENFORCE(false,
                     common::errors::InvalidArgument(
                         "Op assign_value should have corresponding OpInfo "
                         "pd_op.assign_value"));
    }

    return op_info;
  }

  // Fully custom transcription: the new op takes no IR inputs; shape, dtype
  // and the payload are all reconstructed from the legacy attributes.
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    VLOG(10) << "[op assign_value] start transcribing";
    auto op_info = this->LookUpOpInfo(ctx, op_desc);
    auto* op_info_concept =
        op_info.GetInterfaceImpl<dialect::OpYamlInfoInterface>();
    OpInputInfoList input_infos;
    OpAttributeInfoList attr_infos;
    OpOutputInfoList output_infos;
    std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
        op_info_concept->get_op_info_(op_info.name());
    // Index attribute metadata by name so each attribute's expected
    // type_name can be looked up when translating below.
    std::unordered_map<std::string, OpAttributeInfo> attr_info_maps;
    for (auto const& info : attr_infos) {
      attr_info_maps.insert({info.name, info});
    }

    auto& attribute_translator = AttributeTranslator::instance();
    pir::AttributeMap attribute_map;

    // `shape` is mandatory in the legacy op description.
    paddle::framework::Attribute legacy_attr;
    if (op_desc.HasAttr("shape")) {
      legacy_attr = op_desc.GetAttr("shape");
    } else {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op assign_value should have attribute `shape` but not find"));
    }
    pir::Attribute attr_shape =
        attribute_translator(attr_info_maps.at("shape").type_name, legacy_attr);
    attribute_map["shape"] = attr_shape;

    // `dtype` is mandatory in the legacy op description.
    if (op_desc.HasAttr("dtype")) {
      legacy_attr = op_desc.GetAttr("dtype");
    } else {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op assign_value should have attribute `dtype` but not find"));
    }
    pir::Attribute attr_dtype =
        attribute_translator(attr_info_maps.at("dtype").type_name, legacy_attr);
    attribute_map["dtype"] = attr_dtype;

    // The legacy attributes carry no place information; use UNDEFINED.
    pir::Attribute attr_place = dialect::PlaceAttribute::get(
        ctx, phi::Place(phi::AllocationType::UNDEFINED));
    attribute_map["place"] = attr_place;

    // Legacy programs store the payload under one of several typed attribute
    // names; the first non-empty candidate becomes the new `values`.
    const std::vector<std::string> possible_attrs = {
        "bool_values", "fp32_values", "int32_values", "int64_values", "values"};
    for (const auto& attr_name : possible_attrs) {
      if (!op_desc.HasAttr(attr_name)) {
        continue;
      }
      legacy_attr = op_desc.GetAttr(attr_name);
      pir::Attribute attr_values = attribute_translator(
          attr_info_maps.at("values").type_name, legacy_attr);
      if (attr_values && attr_values.isa<pir::ArrayAttribute>() &&
          !attr_values.dyn_cast<pir::ArrayAttribute>().empty()) {
        attribute_map["values"] = attr_values;
        VLOG(10) << "[op assign_value][values]" << attr_name << " "
                 << attr_values;
        break;
      }
    }

    PADDLE_ENFORCE_NE(attribute_map.find("values"),
                      attribute_map.end(),
                      common::errors::InvalidArgument(
                          "Op assign_value should have attribute "
                          "`**_values` or `values` but not find"));

    TranslateOpDistAttribute(op_desc, &attribute_map);

    VLOG(10) << "[op assign_value] attribute translation done";

    // No IR inputs; the outputs are generated from the OpDesc as usual.
    std::vector<pir::Value> op_inputs = {};

    OpOutputMapping arg_to_idx;
    OpOutputTypeList op_output_types;
    std::tie(op_output_types, arg_to_idx) =
        this->GenerateOperationOutput(ctx, op_desc, output_infos);

    pir::Operation* operation = pir::Operation::Create(
        op_inputs, attribute_map, op_output_types, op_info);
    block->push_back(operation);
    RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

    VLOG(10) << "[op assign_value] translation finished";

    return operation;
  }
};

// The dynamic-graph input `dropout_state_in` has no counterpart in the
// static op definition, so it is synthesized here: a `full` op producing a
// zero tensor with the same type as the OpDesc output `DropoutState`. Note
// that `DropoutState` is an optional output in the static graph; when it is
// absent a null value is returned.
pir::Value TranslateDropOutStateIn(pir::IrContext* ctx,
                                   TranslationContext* param_map,
                                   const OpDesc& op_desc,
                                   const std::string& normalized_op_name,
                                   const OpInputInfo& input_info,
                                   pir::Block* block) {
  const std::string legacy_output_name = "DropoutState";
  std::vector<std::string> state_vars;
  if (op_desc.HasOutput(legacy_output_name)) {
    state_vars = op_desc.Output(legacy_output_name);
  }
  if (state_vars.empty()) {
    VLOG(3) << "[input translating] not find output variable: DropoutState";
    return pir::Value(nullptr);
  }

  // `DropoutState` is a tensor; translate its VarDesc type to learn the
  // dense tensor type to materialize.
  VarDesc* state_var = op_desc.Block()->FindVarRecursive(state_vars[0]);
  PADDLE_ENFORCE_NE(
      state_var,
      nullptr,
      common::errors::InvalidArgument("[op:%s] Output %s should not be null",
                                      op_desc.Type(),
                                      state_vars[0]));
  auto& type_translator = TypeTranslator::instance();
  pir::Type var_type =
      type_translator[state_var->GetType()](ctx, *state_var);
  PADDLE_ENFORCE_EQ(
      var_type.isa<dialect::DenseTensorType>(),
      true,
      common::errors::InvalidArgument(
          "Unexpected: Rnn Op's output DropoutState should be a DenseTensor"));
  auto dense_type = var_type.dyn_cast<dialect::DenseTensorType>();

  // Materialize a zero-filled CPU tensor of the same shape and dtype.
  pir::Builder builder(ctx, block);
  dialect::FullOp full_op = builder.Build<dialect::FullOp>(
      common::vectorize(dense_type.dims()),
      0.0f,
      dialect::TransToPhiDataType(dense_type.dtype()),
      phi::CPUPlace());
  return full_op->result(0);
}

// `rnn` has an additional input in dynamic graph; route `dropout_state_in`
// through the synthesizing handler above.
struct RnnOpTranscriber : public OpTranscriber {
  InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) override {
    if (input_name == "dropout_state_in") {
      return TranslateDropOutStateIn;
    }
    return nullptr;
  }
};

// embedding's grad op: the legacy `is_sparse` attribute selects between the
// dense (`pd_op.embedding_grad`) and sparse (`pd_op.embedding_sparse_grad`)
// target ops.
struct EmbeddingGradOpTranscriber : public OpTranscriber {
  // Defaults when the legacy program omits these attributes.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "padding_idx") {
      (*attribute_map)[info.name] = pir::Int64Attribute::get(ctx, -1);
    } else if (info.name == "sparse") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    }
  }

  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    // The generic name mapping was previously computed here and then always
    // overwritten by the is_sparse branch; compute the target directly.
    const bool is_sparse = paddle::get<bool>(op_desc.GetAttr("is_sparse"));
    const std::string target_op_name =
        is_sparse ? "pd_op.embedding_sparse_grad" : "pd_op.embedding_grad";

    VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to "
            << target_op_name;
    auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      // Fixed format specifiers: both arguments are strings, so %s (the
      // previous %d mismatched the argument types).
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op %s should have corresponding OpInfo %s",
          op_desc.Type(),
          target_op_name));
    }

    return op_info;
  }
};

// feed: its `name` and `col` attributes come from the OpDesc; it consumes
// no IR inputs.
struct FeedOpTranscriber : public OpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    pir::AttributeMap attribute_map;
    attribute_map["name"] =
        pir::StrAttribute::get(ctx, op_desc.OutputArgumentNames()[0]);
    attribute_map["col"] =
        pir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("col"));
    return attribute_map;
  }

  // `feed` receives its data from outside the program: no inputs.
  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    return {};
  }
};

// data: translates name/shape/dtype/place; CUSTOM places additionally carry
// a device id and device type which are folded into the place attribute.
struct DataOpTranscriber : public FeedOpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    const int allocate_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place"));
    const auto alloc_type = static_cast<phi::AllocationType>(allocate_type);
    const int var_dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
    auto phi_dtype = phi::TransToPhiDataType(var_dtype);

    auto& translator = AttributeTranslator::instance();
    pir::Attribute shape = translator(
        "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("shape"));

    pir::AttributeMap attribute_map;
    attribute_map["name"] = pir::StrAttribute::get(
        ctx, op_desc.GetAttrIfExists<std::string>("name"));
    attribute_map["shape"] = shape;
    attribute_map["dtype"] =
        paddle::dialect::DataTypeAttribute::get(ctx, phi_dtype);
    attribute_map["place"] =
        paddle::dialect::PlaceAttribute::get(ctx, phi::Place(alloc_type));

    if (alloc_type == phi::AllocationType::CUSTOM) {
      // Custom devices need the concrete device id and type to rebuild
      // the place.
      int place_device_id =
          PADDLE_GET_CONST(int, op_desc.GetAttr("place_device_id"));
      std::string place_device_type =
          PADDLE_GET_CONST(std::string, op_desc.GetAttr("place_device_type"));
      attribute_map["place"] = paddle::dialect::PlaceAttribute::get(
          ctx,
          phi::Place(alloc_type, place_device_id, place_device_type));
    }

    return attribute_map;
  }
};

// cross_entropy_with_softmax: default the missing `axis` attribute to -1.
struct CrossEntropyWithSoftmaxOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "axis") {
      return;
    }
    (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, -1);
  }
};

// box_coder: missing attributes default to axis = 0 and an empty
// `variance` array.
struct BoxCoderOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "axis") {
      (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 0);
    } else if (info.name == "variance") {
      (*attribute_map)[info.name] =
          pir::ArrayAttribute::get(ctx, std::vector<pir::Attribute>{});
    }
  }
};

struct Im2sequenceOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "out_stride") {
      std::vector<pir::Attribute> vec_out_stride;
      std::vector<int> out_stride = {1, 1};
      for (size_t i = 0; i < static_cast<size_t>(out_stride.size()); i++) {
        pir::Attribute attr_out_stride =
            pir::Int32Attribute::get(pir::IrContext::Instance(), out_stride[i]);

        vec_out_stride.push_back(attr_out_stride);
      }
      (*attribute_map)[info.name] =
          pir::ArrayAttribute::get(ctx, vec_out_stride);
    }
  }
};

// depthwise_conv2d: default the missing `padding_algorithm` attribute to
// "EXPLICIT".
struct DepthwiseConv2dOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "padding_algorithm") {
      return;
    }
    (*attribute_map)[info.name] = pir::StrAttribute::get(ctx, "EXPLICIT");
  }
};

struct SplitOpTranscriber : public OpTranscriber {
  // The legacy split op maps to pd_op.split_with_num when `num` > 0 and to
  // pd_op.split otherwise. Inputs that may come from attributes are
  // materialized as full/full_int_array ops so the PIR operand list
  // [Tensor x, IntArray sections, Scalar(int) axis] is always satisfied.
  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    // input of split is [Tensor x, IntArray sections, Scalar(int) axis)]

    VLOG(10) << "[op:split][input] start";

    std::vector<pir::Value> op_inputs;
    // process first input
    auto x_input_vars = op_desc.Input("X");
    PADDLE_ENFORCE_EQ(
        x_input_vars.size(),
        1UL,
        common::errors::InvalidArgument("x input of split MUST be a tensor"));
    auto x_defining_info = (*param_map)[x_input_vars[0]];
    op_inputs.push_back(x_defining_info.value);

    // process sections; only needed when `num` is not given (num <= 0).
    int num = paddle::get<int>(op_desc.GetAttr("num"));
    if (num <= 0) {
      if (op_desc.HasInput("SectionsTensorList") &&
          !op_desc.Input("SectionsTensorList").empty()) {
        // sections arrive as a list of tensors: combine them into a single
        // vector-typed value.
        auto sec_tensor_list = op_desc.Input("SectionsTensorList");
        auto* combine_op = InsertCombineOperationForTarget(
            ctx, param_map, block, sec_tensor_list);
        op_inputs.push_back(combine_op->result(0));
      } else {
        // sections arrive as an attribute: materialize a full_int_array op.
        auto& attribute_translator = AttributeTranslator::instance();
        pir::Attribute new_attr = attribute_translator(
            "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("sections"));
        auto sec_define_op =
            InsertFullArrayOperationForAttributeInput(ctx, block, new_attr);
        op_inputs.push_back(sec_define_op->result(0));
      }
    }

    // process axis: forwarded from the AxisTensor input when present,
    // otherwise materialized from the "axis" attribute.
    if (op_desc.HasInput("AxisTensor") &&
        !op_desc.Input("AxisTensor").empty()) {
      // get axis from input
      auto axis_var_list = op_desc.Input("AxisTensor");
      PADDLE_ENFORCE_EQ(axis_var_list.size(),
                        1UL,
                        common::errors::InvalidArgument(
                            "axis tensor input of split MUST be a tensor"));
      auto axis_defining_info = (*param_map)[axis_var_list[0]];
      op_inputs.push_back(axis_defining_info.value);
    } else {
      auto& attribute_translator = AttributeTranslator::instance();
      pir::Attribute new_attr =
          attribute_translator("pir::Int32Attribute", op_desc.GetAttr("axis"));

      auto sec_define_op =
          InsertFullOperationForAttributeInput(ctx, block, new_attr);
      op_inputs.push_back(sec_define_op->result(0));
    }

    return op_inputs;
  }

  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    // split_with_num keeps `num` as an attribute; plain split carries none
    // (except oneDNN data-type hints when built with DNNL support).
    int num = paddle::get<int>(op_desc.GetAttr("num"));
    if (num > 0) {
      pir::AttributeMap attribute_map = {
          {"num",
           pir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("num"))},
      };

      return attribute_map;
    }
#ifdef PADDLE_WITH_DNNL
    else {  // NOLINT
      pir::AttributeMap attribute_map = {};
      if (op_desc.HasAttr("mkldnn_data_type")) {
        attribute_map["mkldnn_data_type"] = pir::StrAttribute::get(
            ctx, op_desc.GetAttrIfExists<std::string>("mkldnn_data_type"));
      }
      if (op_desc.HasAttr("onednn_data_type")) {
        attribute_map["onednn_data_type"] = pir::StrAttribute::get(
            ctx, op_desc.GetAttrIfExists<std::string>("onednn_data_type"));
      }
      return attribute_map;
    }
#endif

    return {};
  }

  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    int num = paddle::get<int>(op_desc.GetAttr("num"));
    auto prefix = GetPrefix(ctx, op_desc);
    std::string target_op_name;
    if (num > 0) {
      target_op_name = prefix + "split_with_num";

    } else {
      target_op_name = prefix + "split";
    }

    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      // Fixed copy-paste: this is the split transcriber, not assign_value.
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op split should have corresponding OpInfo %s.", target_op_name));
    }

    return op_info;
  }
};

struct FetchOpTranscriber : public OpTranscriber {
  // fetch is constructed by hand: its "name"/"col" attributes come straight
  // from the legacy op description, and the single output mirrors the type
  // of the (sliced) input value.
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    pir::OpInfo op_info = this->LookUpOpInfo(ctx, op_desc);

    auto* yaml_info_interface =
        op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
    OpInputInfoList input_infos;
    OpAttributeInfoList attr_infos;
    OpOutputInfoList output_infos;
    std::tie(input_infos, attr_infos, output_infos, std::ignore, std::ignore) =
        yaml_info_interface->get_op_info_(op_info.name());

    this->InsertSliceOperationForInput(
        ctx, param_map, op_desc, input_infos, block);

    auto inputs = this->GenerateOperationInput(
        ctx, param_map, op_desc, op_info.name(), input_infos, block);

    pir::AttributeMap attrs = {
        {"name", pir::StrAttribute::get(ctx, op_desc.InputArgumentNames()[0])},
        {"col",
         pir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("col"))},
    };
    TranslateOpDistAttribute(op_desc, &attrs);

    // The fetch result carries the same type as its first input.
    OpOutputTypeList output_types;
    output_types.push_back(inputs[0].type());
    pir::Operation* fetch_op =
        pir::Operation::Create(inputs, attrs, output_types, op_info);
    block->push_back(fetch_op);

    return fetch_op;
  }
};

struct ShadowOutputOpTranscriber : public OpTranscriber {
  // Translates the legacy op into pir::ShadowOutputOp by hand: forwards the
  // single "x" input and maps the legacy "name" attribute onto the PIR
  // "output_name" attribute. The created op produces no results.
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());

    std::vector<pir::Value> op_inputs;
    auto legacy_input_vars = op_desc.Input("x", true);

    auto defining_info = (*param_map)[legacy_input_vars[0]];
    // Values produced as part of a vector are first sliced down to a single
    // value, then the (updated) defining info is re-read from the map.
    // NOTE(review): the assignment below stores `.value` into
    // `defining_info`, so it appears VariableDefiningInfo is implicitly
    // constructible from pir::Value — confirm against its declaration.
    if (defining_info.generated_by_vector) {
      InsertSliceOperationForTarget(
          ctx, param_map, block, defining_info, legacy_input_vars[0]);
      defining_info = param_map->at(legacy_input_vars[0]).value;
    }

    op_inputs.push_back(defining_info.value);

    // GetAttrIfExists yields a default-constructed std::string (empty) when
    // the legacy op has no "name" attribute.
    pir::AttributeMap attribute_map = {
        {"output_name",
         pir::StrAttribute::get(ctx,
                                op_desc.GetAttrIfExists<std::string>("name"))},
    };
    TranslateOpDistAttribute(op_desc, &attribute_map);

    // Empty output-type list: ShadowOutputOp has no results.
    pir::Operation* operation =
        pir::Operation::Create(op_inputs, attribute_map, {}, op_info);
    block->push_back(operation);

    return operation;
  }
};

// NOTE, add_n op in legacy ops don't have a kernel, so we use a new op for now
struct AddNOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    // Resolve the dialect-prefixed target name; inplace variants follow the
    // PIR convention of a trailing underscore.
    std::string target_op_name =
        GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type());
    if (IsInplace(op_desc)) {
      target_op_name.append("_");
    }

    pir::OpInfo op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op add_n should have corresponding OpInfo %s", target_op_name));
    }

    return op_info;
  }
};

struct TrilAndTriuOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    // The legacy tril_triu op splits into two PIR ops; `lower` picks which.
    const bool lower = PADDLE_GET_CONST(bool, op_desc.GetAttr("lower"));
    const std::string target_op_name = lower ? "pd_op.tril" : "pd_op.triu";
    pir::OpInfo op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op tril_triu should have corresponding "
          "OpInfo pd_op.tril or pd_op.triu."));
    }

    return op_info;
  }
};

struct TrilAndTriuGradOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    // Mirror of TrilAndTriuOpTranscriber for the gradient op: `lower`
    // selects between the two PIR grad ops.
    const bool lower = PADDLE_GET_CONST(bool, op_desc.GetAttr("lower"));
    const std::string target_op_name =
        lower ? "pd_op.tril_grad" : "pd_op.triu_grad";
    pir::OpInfo op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op tril_triu_grad should have "
          "corresponding OpInfo pd_op.tril_grad "
          "or "
          "pd_op.triu_grad."));
    }

    return op_info;
  }
};

struct MulOpTranscriber : public OpTranscriber {
  // Translates the legacy `mul` op into pd_op.matmul. Legacy mul flattens X
  // and Y to 2-D according to x_num_col_dims/y_num_col_dims before the
  // multiply, so this transcriber wraps the matmul in reshape ops on the
  // inputs and on the output.
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
#ifdef PADDLE_WITH_DNNL
    // oneDNN builds keep the generic transcription path; the slice to the
    // base class is deliberate (NOLINT).
    if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
        op_desc.GetAttrIfExists<bool>("use_onednn")) {
      return static_cast<OpTranscriber>(*this).operator()(  // NOLINT
          ctx,
          param_map,
          op_desc,
          block);
    }
#endif
    return OpTranscriber::operator()(ctx, param_map, op_desc, block);
  }

  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    const std::string& target_op_name = paddle::dialect::MatmulOp::name();
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      // %s for both arguments: op type and op name are strings (the
      // original %d specifiers were a bug).
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op %s should have corresponding OpInfo %s",
          op_desc.Type(),
          target_op_name));
    }
    return op_info;
  }

  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    // mul never transposes; matmul's transpose flags are fixed to false.
    pir::AttributeMap attribute_map = {};

    attribute_map["transpose_x"] = pir::BoolAttribute::get(ctx, false);
    attribute_map["transpose_y"] = pir::BoolAttribute::get(ctx, false);

    return attribute_map;
  }

  // Reshapes X and Y to 2-D ([prod(dims[:num_col_dims]),
  // prod(dims[num_col_dims:])]) and returns the reshaped values as the
  // matmul operands.
  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    const int x_num_col_dims =
        PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims"));
    const int y_num_col_dims =
        PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims"));

    ValueInfo x_info = GetTensorInfoByVarName(
        op_desc, op_desc.Input("X", true), param_map, "X");

    const auto& [x_shape, x_tensor_type, x_value] = x_info;

    PADDLE_ENFORCE_EQ(
        x_num_col_dims <= static_cast<int>(x_shape.size()),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s attr `x_num_col_dims` less than or equal to "
            "dim of input X %s, but got %d",
            op_desc.Type(),
            x_shape.size(),
            x_num_col_dims));

    ValueInfo y_info = GetTensorInfoByVarName(
        op_desc, op_desc.Input("Y", true), param_map, "Y");

    const auto& [y_shape, y_tensor_type, y_value] = y_info;

    PADDLE_ENFORCE_EQ(
        y_num_col_dims <= static_cast<int>(y_shape.size()),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s attr `y_num_col_dims` less than or equal to "
            "dim of input Y %s, but got %d",
            op_desc.Type(),
            y_shape.size(),
            y_num_col_dims));

    pir::Builder builder(ctx, block);

    // std::max(..., -1) keeps a -1 (dynamic) extent whenever any folded
    // dimension is dynamic, since the product then stays negative.
    std::vector<int64_t> x_new_shape({
        std::max(std::accumulate(x_shape.begin(),
                                 x_shape.begin() + x_num_col_dims,
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>()),
                 static_cast<int64_t>(-1)),
        std::max(std::accumulate(x_shape.begin() + x_num_col_dims,
                                 x_shape.end(),
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>()),
                 static_cast<int64_t>(-1)),
    });
    dialect::ReshapeOp reshape_op_x =
        builder.Build<dialect::ReshapeOp>(x_value, x_new_shape);
    pir::Value x_new = reshape_op_x.out();
    VLOG(6) << "[" << op_desc.Type() << "] x_shape change from "
            << x_tensor_type.dims() << " to " << common::make_ddim(x_new_shape);

    std::vector<int64_t> y_new_shape(
        {std::max(std::accumulate(y_shape.begin(),
                                  y_shape.begin() + y_num_col_dims,
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>()),
                  static_cast<int64_t>(-1)),
         std::max(std::accumulate(y_shape.begin() + y_num_col_dims,
                                  y_shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>()),
                  static_cast<int64_t>(-1))});

    dialect::ReshapeOp reshape_op_y =
        builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
    pir::Value y_new = reshape_op_y.out();
    VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
            << y_tensor_type.dims() << " to " << common::make_ddim(y_new_shape);

    return {x_new, y_new};
  }

  // Reshapes the 2-D matmul result back to the rank the legacy mul op
  // produced (x dims before x_num_col_dims + y dims after y_num_col_dims)
  // and re-registers the output variable to point at the reshaped value.
  void RecordOpResultMapping(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Operation* operation,
                             const OpOutputMapping& arg_to_idx) override {
    OpTranscriber::RecordOpResultMapping(
        ctx, param_map, op_desc, operation, arg_to_idx);
    if (op_desc.HasOutput("Out")) {
      ValueInfo out_info = GetTensorInfoByVarName(
          op_desc, op_desc.Output("Out"), param_map, "Out");

      const dialect::DenseTensorType& out_tensor_type = std::get<1>(out_info);
      pir::Value& out_value = std::get<2>(out_info);

      const auto& output_vars = op_desc.Output("Out");
      const auto& output_name = output_vars[0];

      const int x_num_col_dims =
          PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims"));
      const int y_num_col_dims =
          PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims"));

      ValueInfo x_info = GetTensorInfoByVarName(
          op_desc, op_desc.Input("X", true), param_map, "X");

      const std::vector<int64_t>& x_shape = std::get<0>(x_info);

      ValueInfo y_info = GetTensorInfoByVarName(
          op_desc, op_desc.Input("Y", true), param_map, "Y");

      const std::vector<int64_t>& y_shape = std::get<0>(y_info);

      std::vector<int64_t> out_new_shape(x_shape.begin(),
                                         x_shape.begin() + x_num_col_dims);
      out_new_shape.insert(
          out_new_shape.end(), y_shape.begin() + y_num_col_dims, y_shape.end());

      pir::Builder builder(ctx, operation->GetParent());
      dialect::ReshapeOp reshape_op_out =
          builder.Build<dialect::ReshapeOp>(out_value, out_new_shape);
      pir::Value out_new = reshape_op_out.out();
      VLOG(6) << "[" << op_desc.Type() << "] out_shape change from "
              << out_tensor_type.dims() << " to "
              << common::make_ddim(out_new_shape);

      param_map->PushValue(output_name,
                           VariableDefiningInfo(out_new, false, -1));
    }
  }
};

struct MulGradOpTranscriber : public OpTranscriber {
  // Translates the legacy `mul_grad` op into pd_op.matmul_grad. X, Y and
  // Out@GRAD are flattened to 2-D (as the forward mul does), and the
  // produced gradients are reshaped back to the original input ranks.
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
#ifdef PADDLE_WITH_DNNL
    // oneDNN builds keep the generic transcription path; the slice to the
    // base class is deliberate (NOLINT).
    if (op_desc.GetAttrIfExists<bool>("use_mkldnn") ||
        op_desc.GetAttrIfExists<bool>("use_onednn")) {
      return static_cast<OpTranscriber>(*this).operator()(  // NOLINT
          ctx,
          param_map,
          op_desc,
          block);
    }
#endif
    return OpTranscriber::operator()(ctx, param_map, op_desc, block);
  }

  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    const std::string& target_op_name = paddle::dialect::MatmulGradOp::name();
    VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to "
            << target_op_name;
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      // %s for both arguments: op type and op name are strings (the
      // original %d specifiers were a bug).
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op %s should have corresponding OpInfo %s",
          op_desc.Type(),
          target_op_name));
    }
    return op_info;
  }

  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    // mul never transposes; matmul_grad's transpose flags are fixed false.
    pir::AttributeMap attribute_map = {};

    attribute_map["transpose_x"] = pir::BoolAttribute::get(ctx, false);
    attribute_map["transpose_y"] = pir::BoolAttribute::get(ctx, false);

    return attribute_map;
  }

  // Reshapes X and Y to 2-D and Out@GRAD to the matching [rows(X), cols(Y)]
  // shape; returns the three reshaped operands for matmul_grad.
  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    const int x_num_col_dims =
        PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims"));
    const int y_num_col_dims =
        PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims"));

    ValueInfo x_info = GetTensorInfoByVarName(
        op_desc, op_desc.Input("X", true), param_map, "X");

    const auto& [x_shape, x_tensor_type, x_value] = x_info;

    PADDLE_ENFORCE_EQ(
        x_num_col_dims <= static_cast<int>(x_shape.size()),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s attr `x_num_col_dims` less than or equal to "
            "dim of input X %s, but got %d",
            op_desc.Type(),
            x_shape.size(),
            x_num_col_dims));

    ValueInfo y_info = GetTensorInfoByVarName(
        op_desc, op_desc.Input("Y", true), param_map, "Y");

    const auto& [y_shape, y_tensor_type, y_value] = y_info;

    PADDLE_ENFORCE_EQ(
        y_num_col_dims <= static_cast<int>(y_shape.size()),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s attr `y_num_col_dims` less than or equal to "
            "dim of input Y %s, but got %d",
            op_desc.Type(),
            y_shape.size(),
            y_num_col_dims));

    ValueInfo out_grad_info = GetTensorInfoByVarName(
        op_desc, op_desc.Input("Out@GRAD", true), param_map, "Out@GRAD");

    const dialect::DenseTensorType& out_grad_tensor_type =
        std::get<1>(out_grad_info);
    pir::Value& out_grad_value = std::get<2>(out_grad_info);

    pir::Builder builder(ctx, block);

    // std::max(..., -1) keeps a -1 (dynamic) extent whenever any folded
    // dimension is dynamic, since the product then stays negative.
    std::vector<int64_t> x_new_shape({
        std::max(std::accumulate(x_shape.begin(),
                                 x_shape.begin() + x_num_col_dims,
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>()),
                 static_cast<int64_t>(-1)),
        std::max(std::accumulate(x_shape.begin() + x_num_col_dims,
                                 x_shape.end(),
                                 static_cast<int64_t>(1),
                                 std::multiplies<int64_t>()),
                 static_cast<int64_t>(-1)),
    });
    dialect::ReshapeOp reshape_op_x =
        builder.Build<dialect::ReshapeOp>(x_value, x_new_shape);
    pir::Value x_new = reshape_op_x.out();
    VLOG(6) << "[" << op_desc.Type() << "] x_shape change from "
            << x_tensor_type.dims() << " to " << common::make_ddim(x_new_shape);

    std::vector<int64_t> y_new_shape(
        {std::max(std::accumulate(y_shape.begin(),
                                  y_shape.begin() + y_num_col_dims,
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>()),
                  static_cast<int64_t>(-1)),
         std::max(std::accumulate(y_shape.begin() + y_num_col_dims,
                                  y_shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>()),
                  static_cast<int64_t>(-1))});

    dialect::ReshapeOp reshape_op_y =
        builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
    pir::Value y_new = reshape_op_y.out();
    VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
            << y_tensor_type.dims() << " to " << common::make_ddim(y_new_shape);

    std::vector<int64_t> out_grad_new_shape(
        {x_new_shape.front(), y_new_shape.back()});

    dialect::ReshapeOp reshape_op_out_grad =
        builder.Build<dialect::ReshapeOp>(out_grad_value, out_grad_new_shape);
    pir::Value out_grad_new = reshape_op_out_grad.out();
    VLOG(6) << "[" << op_desc.Type() << "] out_grad_shape change from "
            << out_grad_tensor_type.dims() << " to "
            << common::make_ddim(out_grad_new_shape);

    return {x_new, y_new, out_grad_new};
  }

  // Reshapes X@GRAD / Y@GRAD (which matmul_grad produces in 2-D) back to
  // the shapes of the corresponding forward inputs and re-registers them.
  void RecordOpResultMapping(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Operation* operation,
                             const OpOutputMapping& arg_to_idx) override {
    OpTranscriber::RecordOpResultMapping(
        ctx, param_map, op_desc, operation, arg_to_idx);

    const auto& x_grad_output = op_desc.Output("X@GRAD");
    const auto& y_grad_output = op_desc.Output("Y@GRAD");
    if (x_grad_output.empty() && y_grad_output.empty()) {
      return;
    }

    pir::Builder builder(ctx, operation->GetParent());

    auto gradReshape = [&](const std::string& var_name) {
      const auto& grad_output = op_desc.Output(var_name);
      PADDLE_ENFORCE_EQ(
          grad_output.size(),
          1UL,
          common::errors::InvalidArgument(
              "Expected op[%s]'s output %s has only 1 variable, but got %d",
              op_desc.Type(),
              var_name,
              grad_output.size()));
      const auto& grad_var_name = grad_output[0];

      auto idx_iter = arg_to_idx.find(grad_var_name);
      if (idx_iter == arg_to_idx.end()) {
        PADDLE_THROW(common::errors::InvalidArgument(
            "op[%s] should have got its %s", op_desc.Type(), var_name));
      }
      auto [idx_in_op, idx_in_vec] = idx_iter->second;
      VLOG(10) << "[output recording]"
               << "[" << op_desc.Type() << "]" << grad_var_name << " "
               << idx_in_op << " " << idx_in_vec;

      // "X@GRAD" -> forward input "X" (likewise for Y): the gradient is
      // reshaped back to the forward input's shape.
      VarDesc* var_desc = op_desc.Block()->FindVarRecursive(
          op_desc.Input(var_name.substr(0, 1))[0]);
      PADDLE_ENFORCE_NE(
          var_desc,
          nullptr,
          common::errors::InvalidArgument("[op:%s] Input %s should not be null",
                                          op_desc.Type(),
                                          var_name.substr(0, 1)));
      std::vector<int64_t> shape = var_desc->GetShape();
      DenseTensorTypeStorage::Dim dim = common::make_ddim(shape);

      pir::Value value_res = operation->result(idx_in_op);
      // Validate the result BEFORE handing it to the builder (the original
      // code built the reshape first, defeating the check).
      PADDLE_ENFORCE_NE(value_res,
                        nullptr,
                        common::errors::PreconditionNotMet(
                            "Expected op[%s]'s input %s is not null",
                            op_desc.Type(),
                            grad_var_name));
      auto reshape_op = builder.Build<dialect::ReshapeOp>(value_res, shape);
      pir::Type grad_type = value_res.type();
      PADDLE_ENFORCE_EQ(
          grad_type.isa<dialect::DenseTensorType>(),
          true,
          common::errors::InvalidArgument(
              "Expected op[%s]'s input %s is DenseTensor but got %s",
              op_desc.Type(),
              grad_var_name,
              grad_type));
      dialect::DenseTensorType grad_tensor_type =
          grad_type.dyn_cast<dialect::DenseTensorType>();

      VLOG(10) << "[" << op_desc.Type() << "] shape of " << var_name
               << " change from " << grad_tensor_type.dims() << " to " << dim;

      param_map->PushValue(grad_var_name,
                           VariableDefiningInfo(reshape_op.out(), false, -1));
    };

    if (!x_grad_output.empty()) {
      gradReshape("X@GRAD");
    }

    if (!y_grad_output.empty()) {
      gradReshape("Y@GRAD");
    }
  }
};

struct FillConstant2FullTranscriber : public OpTranscriber {
  // Translates fill_constant into pd_op.full when shape/value/dtype are all
  // plain attributes (no mutable tensor inputs).
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    pir::OpInfo op_info = ctx->GetRegisteredOpInfo(dialect::FullOp::name());
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op fill_constant should have corresponding OpInfo pd_op.full"));
    }

    return op_info;
  }

  // pd_op.full has no operands; everything travels through attributes.
  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    return {};
  }

  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& attribute_translator = AttributeTranslator::instance();
    paddle::framework::Attribute shape_attr = op_desc.GetAttr("shape");
    const float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
    const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));

    pir::AttributeMap attribute_map = {
        {"shape",
         attribute_translator("paddle::dialect::IntArrayAttribute",
                              shape_attr)},
        {"value",
         pir::FloatAttribute::get(ctx, value)
             .dyn_cast<paddle::dialect::ScalarAttribute>()},
        {"dtype",
         paddle::dialect::DataTypeAttribute::get(
             ctx,
             paddle::translator::VarTypeToDataType(
                 static_cast<paddle::framework::proto::VarType_Type>(dtype)))}};

    // force_cpu overrides place_type; an unrecognized place_type leaves the
    // "place" attribute unset, matching the legacy switch's fall-through.
    int place_type = -1;
    if (op_desc.HasAttr("place_type")) {
      place_type = PADDLE_GET_CONST(int, op_desc.GetAttr("place_type"));
    }
    if (op_desc.HasAttr("force_cpu") &&
        PADDLE_GET_CONST(bool, op_desc.GetAttr("force_cpu"))) {
      place_type = 0;
    }
    if (place_type == -1) {
      attribute_map["place"] = paddle::dialect::PlaceAttribute::get(
          ctx, phi::Place(phi::AllocationType::UNDEFINED));
    } else if (place_type == 0) {
      attribute_map["place"] =
          paddle::dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
    } else if (place_type == 1) {
      attribute_map["place"] =
          paddle::dialect::PlaceAttribute::get(ctx, phi::GPUPlace());
    } else if (place_type == 2) {
      attribute_map["place"] =
          paddle::dialect::PlaceAttribute::get(ctx, phi::GPUPinnedPlace());
    } else if (place_type == 3) {
      attribute_map["place"] =
          paddle::dialect::PlaceAttribute::get(ctx, phi::XPUPlace());
    }

    return attribute_map;
  }
};

struct FillConstant2FullWithTensorTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    const auto& op_info = ctx->GetRegisteredOpInfo("pd_op.full_with_tensor");
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op fill_constant should have corresponding OpInfo "
          "pd_op.full_with_tensor"));
    }

    return op_info;
  }

  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    std::vector<pir::Value> op_inputs;
    if (op_desc.HasInput("ValueTensor", true) &&
        op_desc.Input("ValueTensor", true).size() > 0) {
      auto value_tensor_vars = op_desc.Input("ValueTensor", true);
      auto defining_info = (*param_map)[value_tensor_vars[0]];
      op_inputs.push_back(defining_info.value);
    } else {
      float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
      pir::Attribute new_attr = pir::FloatAttribute::get(ctx, value);
      auto defining_op =
          InsertFullOperationForAttributeInput(ctx, block, new_attr);
      op_inputs.push_back(defining_op->result(0));
    }

    if (op_desc.HasInput("ShapeTensor", true) &&
        op_desc.Input("ShapeTensor", true).size() > 0) {
      auto shape_tensor_vars = op_desc.Input("ShapeTensor", true);
      auto defining_info = (*param_map)[shape_tensor_vars[0]];
      op_inputs.push_back(defining_info.value);
    } else if (op_desc.HasInput("ShapeTensorList", true) &&
               op_desc.Input("ShapeTensorList", true).size() > 0) {
      auto shape_tensor_list_vars = op_desc.Input("ShapeTensorList", true);
      auto defining_op = InsertStackOperationForTarget(
          ctx, param_map, block, shape_tensor_list_vars);
      op_inputs.push_back(defining_op->result(0));
    } else {
      auto& attribute_translator = AttributeTranslator::instance();
      paddle::framework::Attribute shape_attr = op_desc.GetAttr("shape");
      pir::Attribute new_attr = attribute_translator(
          "paddle::dialect::IntArrayAttribute", shape_attr);
      auto defining_op =
          InsertFullArrayOperationForAttributeInput(ctx, block, new_attr);
      op_inputs.push_back(defining_op->result(0));
    }

    return op_inputs;
  }

  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));

    pir::AttributeMap attribute_map = {
        {"dtype",
         paddle::dialect::DataTypeAttribute::get(
             ctx,
             paddle::translator::VarTypeToDataType(
                 static_cast<paddle::framework::proto::VarType_Type>(dtype)))}};
    return attribute_map;
  }
};

struct FillConstantTranscriber : public OpTranscriber {
  // Dispatches fill_constant: pd_op.full when shape/value are compile-time
  // attributes, pd_op.full_with_tensor when any of them arrives through a
  // tensor input ("mutable attribute").
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    auto has_nonempty_input = [&op_desc](const char* name) {
      return op_desc.HasInput(name, true) && !op_desc.Input(name, true).empty();
    };
    const bool has_mutable_attribute = has_nonempty_input("ShapeTensor") ||
                                       has_nonempty_input("ShapeTensorList") ||
                                       has_nonempty_input("ValueTensor");

    if (has_mutable_attribute) {
      return FillConstant2FullWithTensorTranscriber()(
          ctx, param_map, op_desc, block);
    }
    return FillConstant2FullTranscriber()(ctx, param_map, op_desc, block);
  }
};

struct SelectInputOpTranscriber : public OpTranscriber {
  // Translates legacy `select_input` (Mask + two candidate X inputs -> Out).
  // The output type is the candidates' common type; when both candidates are
  // DenseTensor differing only in dims, mismatched dims are merged to -1.
  // One candidate may be an `undefined_var_*` placeholder produced by
  // assign_value; its type and attrs are patched to match the real candidate.
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    VLOG(10) << "[op select_input] start transcribing";
    auto op_info = this->LookUpOpInfo(ctx, op_desc);

    // Inputs: mask first, then the two candidate values.
    std::vector<pir::Value> op_inputs = {};
    auto Mask_name = op_desc.Input("Mask")[0];
    auto& Input_name = op_desc.Input("X");
    PADDLE_ENFORCE_GT(param_map->count(Mask_name),
                      0UL,
                      common::errors::InvalidArgument(
                          "Expected op[%s]'s input %s has been parsed",
                          op_desc.Type(),
                          Mask_name));
    op_inputs.push_back(param_map->at(Mask_name).value);
    for (const auto& in_name : Input_name) {
      PADDLE_ENFORCE_GT(param_map->count(in_name),
                        0UL,
                        common::errors::InvalidArgument(
                            "Expected op[%s]'s input %s has been parsed",
                            op_desc.Type(),
                            in_name));
      op_inputs.push_back(param_map->at(in_name).value);
    }

    pir::AttributeMap attribute_map;
    TranslateOpDistAttribute(op_desc, &attribute_map);

    OpOutputMapping arg_to_idx;
    OpOutputTypeList op_output_types;
    auto Out_name = op_desc.Output("Out")[0];
    VarDesc* var = op_desc.Block()->FindVarRecursive(Out_name);
    arg_to_idx[var->Name()] = {0, 0};

    auto input1 = op_inputs[1].type();
    auto input2 = op_inputs[2].type();
    if (input1 == input2) {
      // Identical candidate types: the output simply reuses that type.
      op_output_types.push_back(op_inputs[1].type());
    } else if (input1.isa<paddle::dialect::DenseTensorType>() &&
               input2.isa<paddle::dialect::DenseTensorType>()) {
      auto tensor1 = input1.dyn_cast<paddle::dialect::DenseTensorType>();
      auto tensor2 = input2.dyn_cast<paddle::dialect::DenseTensorType>();
      if (tensor1.dtype() != tensor2.dtype() ||
          tensor1.data_layout() != tensor2.data_layout() ||
          tensor1.lod() != tensor2.lod() ||
          tensor1.offset() != tensor2.offset()) {
        // The types differ beyond dims: this is only legal when one side is
        // an undefined_var_* placeholder which we patch to match the other.
        const std::string undefined_prefix = "undefined_var_";

        size_t undefined_var_index = 0;
        size_t target_var_index = 1;
        if (Input_name[undefined_var_index].substr(
                0, undefined_prefix.size()) != undefined_prefix &&
            Input_name[target_var_index].substr(0, undefined_prefix.size()) ==
                undefined_prefix) {
          std::swap(undefined_var_index, target_var_index);
        } else if (Input_name[undefined_var_index].substr(
                       0, undefined_prefix.size()) == undefined_prefix) {
          // do nothing
        } else {
          PADDLE_THROW(common::errors::InvalidArgument(
              "select_input only support same type or DenseTensorType with "
              "only different dim, but get dtype:[%s, %s], layout:[%s, "
              "%s], "
              "lod:[%s, %s], offset:[%s, %s].",
              tensor1.dtype(),
              tensor2.dtype(),
              tensor1.data_layout(),
              tensor2.data_layout(),
              tensor1.lod(),
              tensor2.lod(),
              tensor1.offset(),
              tensor2.offset()));
        }

        auto undefined_var_type = tensor1;
        auto target_var_type = tensor2;
        if (undefined_var_index == 1) {
          std::swap(undefined_var_type, target_var_type);
        }

        // Rewrite the placeholder's defining assign_value op so its result
        // type and dtype/shape attributes agree with the real candidate.
        auto undefine_value = op_inputs[1 + undefined_var_index];
        PADDLE_ENFORCE_EQ(
            undefine_value.defining_op()->isa<dialect::AssignValueOp>(),
            true,
            common::errors::InvalidArgument(
                "undefined_var %s should be generated "
                "by assign_value, but got %s",
                Input_name[undefined_var_index],
                undefine_value.defining_op()));

        undefine_value.set_type(target_var_type);
        undefine_value.defining_op()->set_attribute(
            "dtype",
            dialect::DataTypeAttribute::get(
                ctx, dialect::TransToPhiDataType(undefined_var_type.dtype())));
        auto& attribute_translator = AttributeTranslator::instance();
        undefine_value.defining_op()->set_attribute(
            "shape",
            attribute_translator("pir::ArrayAttribute<pir::Int32Attribute>",
                                 common::vectorize(undefined_var_type.dims())));
      }
      // Merge dims: keep equal dims, use -1 where the candidates disagree.
      auto dim1 = input1.dyn_cast<paddle::dialect::DenseTensorType>().dims();
      auto dim2 = input2.dyn_cast<paddle::dialect::DenseTensorType>().dims();
      auto compute_compatible_dim =
          [](const common::DDim& dim1,
             const common::DDim& dim2) -> common::DDim {
        std::vector<int64_t> result;
        for (int i = 0; i < std::min(dim1.size(), dim2.size()); ++i) {
          if (dim1[i] != dim2[i]) {
            result.push_back(-1);
          } else {
            result.push_back(dim1[i]);
          }
        }
        return common::make_ddim(result);
      };
      auto dim = compute_compatible_dim(dim1, dim2);
      op_output_types.push_back(
          paddle::dialect::DenseTensorType::get(ctx,
                                                tensor1.dtype(),
                                                dim,
                                                tensor1.data_layout(),
                                                tensor1.lod(),
                                                tensor1.offset()));
    } else {
      PADDLE_THROW(common::errors::InvalidArgument(
          "select_input only support same type or "
          "DenseTensorType with only "
          "different dim, now is %s != %s.",
          input1,
          input2));
    }

    pir::Operation* operation = pir::Operation::Create(
        op_inputs, attribute_map, op_output_types, op_info);
    block->push_back(operation);
    RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

    // Fixed copy-paste: this transcriber handles select_input, not
    // assign_value.
    VLOG(10) << "[op select_input] translation finished";
    return operation;
  }
};

struct SelectOutputOpTranscriber : public OpTranscriber {
  // Translates legacy `select_output` (Mask + one X input -> two Out vars).
  // Both outputs take the input's type; RecordOpResultMapping registers them
  // via arg_to_idx.
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    VLOG(10) << "[op select_output] start transcribing";
    auto op_info = this->LookUpOpInfo(ctx, op_desc);

    std::vector<pir::Value> op_inputs = {};
    auto Mask_name = op_desc.Input("Mask")[0];
    auto& Input_name = op_desc.Input("X")[0];
    PADDLE_ENFORCE_GT(param_map->count(Mask_name),
                      0UL,
                      common::errors::InvalidArgument(
                          "Expected op[%s]'s input %s has been parsed",
                          op_desc.Type(),
                          Mask_name));
    op_inputs.push_back(param_map->at(Mask_name).value);
    PADDLE_ENFORCE_GT(param_map->count(Input_name),
                      0UL,
                      common::errors::InvalidArgument(
                          "Expected op[%s]'s input %s has been parsed",
                          op_desc.Type(),
                          Input_name));
    op_inputs.push_back(param_map->at(Input_name).value);

    pir::AttributeMap attribute_map;
    TranslateOpDistAttribute(op_desc, &attribute_map);

    OpOutputMapping arg_to_idx;
    OpOutputTypeList op_output_types;
    auto Out_names = op_desc.Output("Out");
    PADDLE_ENFORCE_EQ(Out_names.size(),
                      2UL,
                      common::errors::InvalidArgument(
                          "Expected SelectOutput's output size is 2."));
    // Each output variable receives the input value's type unchanged.
    for (size_t idx = 0; idx < Out_names.size(); idx++) {
      VarDesc* var = op_desc.Block()->FindVarRecursive(Out_names[idx]);
      arg_to_idx[var->Name()] = {idx, 0};
      op_output_types.push_back(op_inputs[1].type());
    }

    pir::Operation* operation = pir::Operation::Create(
        op_inputs, attribute_map, op_output_types, op_info);
    block->push_back(operation);
    RecordOpResultMapping(ctx, param_map, op_desc, operation, arg_to_idx);

    // Fixed copy-paste: this transcriber handles select_output, not
    // assign_value.
    VLOG(10) << "[op select_output] translation finished";
    return operation;
  }
};

// Resolves one_hot's `num_classes` input: prefer the legacy `depth_tensor`
// input's already-translated value; otherwise materialize the `depth`
// attribute as a full op and feed its result.
pir::Value TranslateNumClassesForOneHot(pir::IrContext* ctx,
                                        TranslationContext* param_map,
                                        const OpDesc& op_desc,
                                        const std::string& normalized_op_name,
                                        const OpInputInfo& input_info,
                                        pir::Block* block) {
  const std::string legacy_attr_name = "depth";
  const std::string legacy_tensor_name = "depth_tensor";
  if (op_desc.HasInput(legacy_tensor_name) &&
      !op_desc.Input(legacy_tensor_name).empty()) {
    const auto& legacy_vars = op_desc.Input(legacy_tensor_name);
    // Fixed: the size check was duplicated and the var name copied unused.
    PADDLE_ENFORCE_EQ(legacy_vars.size(),
                      1UL,
                      common::errors::InvalidArgument(
                          "depth_tensor input of one hot MUST be a tensor"));
    const auto& var_name = legacy_vars[0];
    PADDLE_ENFORCE_NE(
        param_map->count(var_name),
        0UL,
        common::errors::InvalidArgument(
            "%s should be existed in one_hot_v2 as input depth_tensor.",
            var_name));
    return param_map->at(var_name).value;
  }

  auto& attribute_translator = AttributeTranslator::instance();
  if (!op_desc.HasAttr(legacy_attr_name)) {
    PADDLE_THROW(
        common::errors::InvalidArgument("Op %s arg %s should not be zero size",
                                        op_desc.Type(),
                                        legacy_attr_name));
  }
  paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
  VLOG(10) << "[" << op_desc.Type() << "][attribute]"
           << " name: " << legacy_attr_name << " " << legacy_attr.index();
  pir::Attribute new_attr = attribute_translator(legacy_attr);

  pir::Operation* defining_op =
      InsertFullOperationForAttributeInput(ctx, block, new_attr);
  return defining_op->result(0);
}

struct OneHotTranscriber : public OpTranscriber {
  // Only the `num_classes` input needs special handling (depth/depth_tensor
  // resolution); every other input uses the default translation path.
  InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) override {
    if (input_name == "num_classes") {
      return TranslateNumClassesForOneHot;
    }
    return nullptr;
  }
};

struct Pool2dOpTranscriber : public OpTranscriber {
  // Supplies defaults for attributes the legacy pool2d op may omit.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    const auto& attr_name = info.name;
    if (attr_name == "exclusive") {
      (*attribute_map)[attr_name] = pir::BoolAttribute::get(ctx, true);
    } else if (attr_name == "adaptive") {
      (*attribute_map)[attr_name] = pir::BoolAttribute::get(ctx, false);
    } else if (attr_name == "padding_algorithm") {
      (*attribute_map)[attr_name] = pir::StrAttribute::get(ctx, "EXPLICIT");
    }
  }
};

struct Pool3dOpTranscriber : public OpTranscriber {
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "exclusive") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, true);
    }
    if (info.name == "adaptive") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    }
    if (info.name == "padding_algorithm") {
      (*attribute_map)[info.name] = pir::StrAttribute::get(ctx, "EXPLICIT");
    }
  }
};

// arange's output dtype follows its Start input's declared data type rather
// than any explicit attribute; build the DataType attribute from that VarDesc.
pir::Attribute TranslateDtypeForArange(pir::IrContext* ctx,
                                       const OpDesc& op_desc,
                                       const OpAttributeInfo& attr_info) {
  const auto& start_names = op_desc.Input("Start");
  PADDLE_ENFORCE_EQ(
      start_names.size(),
      1UL,
      common::errors::InvalidArgument(
          "[op:%s] Input [Start]'s size should be equal to 1", op_desc.Type()));
  VarDesc* start_var = op_desc.Block()->FindVarRecursive(start_names[0]);
  PADDLE_ENFORCE_NE(
      start_var,
      nullptr,
      common::errors::InvalidArgument("[op:%s] Input %s should not be null",
                                      op_desc.Type(),
                                      start_names[0]));
  return paddle::dialect::DataTypeAttribute::get(
      ctx, phi::TransToPhiDataType(start_var->GetDataType()));
}

struct ArangeOpTranscriber : public OpTranscriber {
  // Only `dtype` is special: it is derived from the Start input's VarDesc.
  AttributeHandlerFn GetSpecialAttributeHandlers(
      const std::string& attr_name) override {
    if (attr_name == "dtype") {
      return TranslateDtypeForArange;
    }
    return nullptr;
  }
};

// Legacy reduce ops encode "reduce over every axis" as a `reduce_all` flag;
// the new IR represents it as an empty axis array. Otherwise the legacy dims
// attribute is translated verbatim.
pir::Attribute TranslateReduceAll(pir::IrContext* ctx,
                                  const OpDesc& op_desc,
                                  const OpAttributeInfo& attr_info) {
  const bool reduce_all = op_desc.HasAttr("reduce_all") &&
                          paddle::get<bool>(op_desc.GetAttr("reduce_all"));
  if (reduce_all) {
    return pir::ArrayAttribute::get(ctx, std::vector<pir::Attribute>{});
  }

  auto legacy_attr_name = OpNameNormalizer::instance().GetLegacyAttrName(
      op_desc.Type(), attr_info.name);
  paddle::framework::Attribute dims = op_desc.GetAttr(legacy_attr_name);
  return AttributeTranslator::instance()(attr_info.type_name, dims);
}

struct ReduceOpTranscriber : public OpTranscriber {
  // Only `axis` is special: it honors the legacy `reduce_all` flag.
  AttributeHandlerFn GetSpecialAttributeHandlers(
      const std::string& attr_name) override {
    if (attr_name == "axis") {
      return TranslateReduceAll;
    }
    return nullptr;
  }
};

struct ElementwiseTranscriber : public OpTranscriber {
  // Legacy elementwise ops broadcast Y against X starting at attribute
  // `axis`; the new IR ops align trailing dimensions only. When axis != -1,
  // this override reshapes Y (padding with size-1 dims) so that trailing
  // alignment reproduces the legacy axis-based broadcast.
  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    int axis = paddle::get<int>(op_desc.GetAttr("axis"));

    // axis == -1 already means trailing alignment: default handling works.
    if (axis == -1) {
      return OpTranscriber::GenerateOperationInput(
          ctx, param_map, op_desc, normalized_op_name, input_infos, block);
    }

    // Resolve X: must be exactly one parsed DenseTensor variable.
    auto x_names = op_desc.Input("X", true);
    PADDLE_ENFORCE_EQ(
        x_names.size(),
        1UL,
        common::errors::InvalidArgument(
            "Expected op[%s]'s input X has only 1 variable, but got %d",
            op_desc.Type(),
            x_names.size()));
    auto x_name = x_names[0];
    PADDLE_ENFORCE_GT(param_map->count(x_name),
                      0UL,
                      common::errors::InvalidArgument(
                          "Expected op[%s]'s input %s has been parsed",
                          op_desc.Type(),
                          x_name));
    auto x_defining_info = param_map->at(x_name);
    // Vector-generated values must be unpacked via a slice op first.
    if (x_defining_info.generated_by_vector) {
      InsertSliceOperationForTarget(
          ctx, param_map, block, x_defining_info, x_name);
      x_defining_info = param_map->at(x_name);
    }
    pir::Value x_value = x_defining_info.value;
    PADDLE_ENFORCE_NE(
        x_value,
        nullptr,
        common::errors::PreconditionNotMet(
            "Expected op[%s]'s input %s is not null", op_desc.Type(), x_name));
    pir::Type x_type = x_value.type();
    PADDLE_ENFORCE_EQ(
        x_type.isa<dialect::DenseTensorType>(),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s input %s is DenseTensor but got %s",
            op_desc.Type(),
            x_name,
            x_type));
    dialect::DenseTensorType x_tensor_type =
        x_type.dyn_cast<dialect::DenseTensorType>();
    std::vector<int64_t> x_shape = common::vectorize(x_tensor_type.dims());

    // Resolve Y the same way.
    auto y_names = op_desc.Input("Y", true);
    PADDLE_ENFORCE_EQ(
        y_names.size(),
        1UL,
        common::errors::InvalidArgument(
            "Expected op[%s]'s input Y has only 1 variable, but got %d",
            op_desc.Type(),
            y_names.size()));
    auto y_name = y_names[0];
    PADDLE_ENFORCE_GT(param_map->count(y_name),
                      0UL,
                      common::errors::InvalidArgument(
                          "Expected op[%s]'s input %s has been parsed",
                          op_desc.Type(),
                          y_name));
    auto y_defining_info = param_map->at(y_name);
    if (y_defining_info.generated_by_vector) {
      InsertSliceOperationForTarget(
          ctx, param_map, block, y_defining_info, y_name);
      y_defining_info = param_map->at(y_name);
    }
    pir::Value y_value = y_defining_info.value;
    PADDLE_ENFORCE_NE(
        y_value,
        nullptr,
        common::errors::PreconditionNotMet(
            "Expected op[%s]'s input %s is not null", op_desc.Type(), y_name));
    pir::Type y_type = y_value.type();
    PADDLE_ENFORCE_EQ(
        y_type.isa<dialect::DenseTensorType>(),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s input %s is DenseTensor but got %s",
            op_desc.Type(),
            y_name,
            y_type));
    dialect::DenseTensorType y_tensor_type =
        y_type.dyn_cast<dialect::DenseTensorType>();
    std::vector<int64_t> y_shape = common::vectorize(y_tensor_type.dims());

    // Negative axis counts from the end of X's rank.
    if (axis < 0) {
      axis += static_cast<int>(x_shape.size());
    }

    // Number of trailing size-1 dims Y needs so its rank matches X's.
    int append_size = static_cast<int>(x_shape.size() - axis - y_shape.size());
    if (append_size <= 0) {  // which means x.rank <= y.rank, mostly
                             // x.rank=y.rank
      return {x_value, y_value};
    }
    // NOTE(review): guaranteed by the early return above; kept as a
    // defensive check.
    PADDLE_ENFORCE_GT(
        append_size,
        0UL,
        common::errors::InvalidArgument(
            "Expected op[%s] have append size > 0 with axis=%d but got %d",
            op_desc.Type(),
            axis,
            append_size));

    pir::Builder builder(ctx, block);
    pir::Value y_new;
    if (std::find(y_shape.begin(), y_shape.end(), -1) == y_shape.end()) {
      // Static shape: compute the padded shape at translation time and
      // reshape directly (`axis` leading 1s, `append_size` trailing 1s).
      std::vector<int64_t> y_new_shape(y_shape);
      y_new_shape.insert(y_new_shape.begin(), axis, 1);
      for (int i = 0; i < append_size; i++) {
        y_new_shape.push_back(1);
      }
      dialect::ReshapeOp reshape_op =
          builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
      y_new = reshape_op.out();
      VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
              << y_tensor_type.dims() << " to "
              << common::make_ddim(y_new_shape);
    } else {
      // Dynamic shape (-1 present): build the padded shape at runtime by
      // concatenating Y's shape with `append_size` ones, then reshape.
      auto shape_op = builder.Build<dialect::Shape64Op>(y_value);
      auto append_shape_op = builder.Build<dialect::FullIntArrayOp>(
          std::vector<int64_t>(append_size, 1),
          phi::DataType::INT64,
          phi::CPUPlace());
      auto y_true_shape_op = builder.Build<pir::CombineOp>(
          std::vector<pir::Value>{shape_op.out(), append_shape_op.out()});
      auto concat_op =
          builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
      auto y_new_shape = concat_op.out();
      auto reshape_op = builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
      y_new = reshape_op.out();
    }
    return {x_value, y_new};
  }
};

struct GradAddOpTranscriber : public ElementwiseTranscriber {
  // grad_add lowers to pd_op.add (or the inplace pd_op.add_ variant).
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name = "pd_op.add";
    if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') {
      target_op_name += "_";
    }
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      // Fixed copy-paste: the message previously referenced assign_value.
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op grad_add should have corresponding OpInfo %s", target_op_name));
    }

    return op_info;
  }
};

struct ElementwiseGradTranscriber : public OpTranscriber {
  // After the default output mapping, when legacy `axis` broadcasting was in
  // play (axis != -1), Y@GRAD may come out with Y's broadcast-aligned shape
  // rather than Y's own shape; in that case a reshape back to Y's shape is
  // appended and registered as the Y@GRAD value.
  void RecordOpResultMapping(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Operation* operation,
                             const OpOutputMapping& arg_to_idx) override {
    OpTranscriber::RecordOpResultMapping(
        ctx, param_map, op_desc, operation, arg_to_idx);

    // axis == -1: no legacy axis broadcast, nothing to undo.
    int axis = paddle::get<int>(op_desc.GetAttr("axis"));
    if (axis == -1) {
      return;
    }

    // Y@GRAD is optional; skip when absent.
    const auto& y_grad_output = op_desc.Output("Y@GRAD");
    if (y_grad_output.size() < 1) {
      return;
    }
    PADDLE_ENFORCE_EQ(
        y_grad_output.size(),
        1UL,
        common::errors::InvalidArgument(
            "Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d",
            op_desc.Type(),
            y_grad_output.size()));
    const auto& y_grad_var_name = y_grad_output[0];

    // Locate Y@GRAD's result slot on the translated operation.
    auto idx_iter = arg_to_idx.find(y_grad_var_name);
    if (idx_iter == arg_to_idx.end()) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "op[%s] should have got its y_grad", op_desc.Type()));
    }
    auto [idx_in_op, idx_in_vec] = idx_iter->second;
    VLOG(10) << "[output recording]"
             << "[" << op_desc.Type() << "]" << y_grad_var_name << " "
             << idx_in_op << " " << idx_in_vec;

    // Fetch Y's translated value to recover its original tensor shape.
    auto y_names = op_desc.Input("Y", true);
    auto y_name = y_names[0];
    PADDLE_ENFORCE_GT(param_map->count(y_name),
                      0UL,
                      common::errors::InvalidArgument(
                          "Expected op[%s]'s input %s has been parsed",
                          op_desc.Type(),
                          y_name));
    auto y_defining_info = param_map->at(y_name);
    pir::Value y_value = y_defining_info.value;
    PADDLE_ENFORCE_NE(
        y_value,
        nullptr,
        common::errors::PreconditionNotMet(
            "Expected op[%s]'s input %s is not null", op_desc.Type(), y_name));
    pir::Type y_type = y_value.type();
    PADDLE_ENFORCE_EQ(
        y_type.isa<dialect::DenseTensorType>(),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s input %s is DenseTensor but got %s",
            op_desc.Type(),
            y_name,
            y_type));
    dialect::DenseTensorType y_tensor_type =
        y_type.dyn_cast<dialect::DenseTensorType>();

    pir::Value value = operation->result(idx_in_op);

    // if y_grad' shape is same with y, we don't need a reshape
    pir::Type y_grad_type = value.type();
    PADDLE_ENFORCE_EQ(
        y_grad_type.isa<dialect::DenseTensorType>(),
        true,
        common::errors::InvalidArgument(
            "Expected op[%s]'s input %s is DenseTensor but got %s",
            op_desc.Type(),
            y_grad_var_name,
            y_grad_type));
    dialect::DenseTensorType y_grad_tensor_type =
        y_grad_type.dyn_cast<dialect::DenseTensorType>();
    if (y_grad_tensor_type.dims() == y_tensor_type.dims()) {
      return;
    }

    // Shapes differ: reshape Y@GRAD back to Y's shape and re-register the
    // reshaped value as the mapping for the Y@GRAD variable.
    std::vector<int64_t> y_shape = common::vectorize(y_tensor_type.dims());
    pir::Builder builder(ctx, operation->GetParent());
    auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
    param_map->PushValue(y_grad_var_name,
                         VariableDefiningInfo(reshape_op.out(), false, -1));
  }
};

struct SetValueOpTranscriber : public OpTranscriber {
  // Materializes a legacy attribute (e.g. starts/ends/steps) as a
  // full_int_array op so it can feed the new IR op as a tensor input.
  pir::Value GetAttributeAsInput(pir::IrContext* ctx,
                                 pir::Block* block,
                                 const OpDesc& op_desc,
                                 const OpInputInfo& input_info) override {
    auto legacy_attr_name = OpNameNormalizer::instance().GetLegacyAttrName(
        op_desc.Type(), input_info.name);

    if (!op_desc.HasAttr(legacy_attr_name)) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op %s arg %s should not be zero size",
          op_desc.Type(),
          legacy_attr_name));
    }
    framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
    VLOG(10) << "[" << op_desc.Type() << "][attribute]"
             << " name: " << legacy_attr_name << " " << legacy_attr.index();
    pir::Attribute new_attr = AttributeTranslator::instance()(
        "paddle::dialect::IntArrayAttribute", legacy_attr);

    return InsertFullArrayOperationForAttributeInput(ctx, block, new_attr)
        ->result(0);
  }
};

struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
  // set_value with a ValueTensor input maps to pd_op.set_value_with_tensor;
  // the `values` input is wired directly from the translated ValueTensor.
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name = dialect::SetValueWithTensorOp::name();
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op set_value should have corresponding OpInfo "
          "pd_op.set_value_with_tensor"));
    }

    return op_info;
  }

  InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) override {
    if (input_name != "values") {
      return nullptr;
    }
    return [](pir::IrContext* ctx,
              TranslationContext* param_map,
              const OpDesc& op_desc,
              const std::string&,
              const OpInputInfo& info,
              pir::Block* block) -> pir::Value {
      std::vector<std::string> legacy_input_vars;
      PADDLE_ENFORCE_EQ(op_desc.HasInput("ValueTensor"),
                        true,
                        common::errors::InvalidArgument(
                            "[set_value] should have ValueTensor"));
      legacy_input_vars = op_desc.Input("ValueTensor", true);
      PADDLE_ENFORCE_EQ(legacy_input_vars.size(),
                        1UL,
                        common::errors::InvalidArgument(
                            "[set_value][ValueTensor] should only "
                            "have 1 variable, but got %d",
                            legacy_input_vars.size()));
      auto var_name = legacy_input_vars[0];
      auto defining_info = (*param_map)[var_name];
      if (defining_info.generated_by_vector) {
        InsertSliceOperationForTarget(
            ctx, param_map, block, defining_info, var_name);
        // Fixed: re-read the full defining info (not just `.value`, which
        // implicitly rebuilt a VariableDefiningInfo with default metadata),
        // matching the pattern used by the other transcribers.
        defining_info = param_map->at(var_name);
      }
      return defining_info.value;
    };
  }
};

struct SetValueGradOpTranscriber : public SetValueWithTensorOpTranscriber {
  // set_value_grad always lowers to the tensor-valued gradient op.
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    const std::string target_op_name =
        dialect::SetValueWithTensorGradOp::name();
    auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op set_value_grad should have corresponding OpInfo "
          "pd_op.set_value_with_tensor_grad"));
    }
    return op_info;
  }
};

struct LegacySetValueDispatcher : public OpTranscriber {
  // Dispatches legacy set_value: a non-empty ValueTensor input means the
  // value comes from a tensor (set_value_with_tensor); otherwise the value
  // is an attribute (plain set_value).
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    if (op_desc.HasInput("ValueTensor") &&
        !op_desc.Input("ValueTensor", true).empty()) {
      VLOG(10) << "legacy op:" << op_desc.Type()
               << " has ValueTensor and convert to set_value_with_tensor";
      return SetValueWithTensorOpTranscriber()(ctx, param_map, op_desc, block);
    }
    return SetValueOpTranscriber()(ctx, param_map, op_desc, block);
  }
};

struct FusedFeedForwardOpTranscriber : public OpTranscriber {
  // Supplies defaults for the many optional fused_feedforward attributes the
  // legacy op may omit.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "ln1_epsilon") {
      (*attribute_map)[info.name] = pir::FloatAttribute::get(ctx, 1e-5f);
    } else if (info.name == "ln2_epsilon") {
      (*attribute_map)[info.name] = pir::FloatAttribute::get(ctx, 1e-5f);
    } else if (info.name == "act_method") {
      (*attribute_map)[info.name] = pir::StrAttribute::get(ctx, "gelu");
    } else if (info.name == "dropout1_prob") {
      (*attribute_map)[info.name] = pir::FloatAttribute::get(ctx, .5f);
    } else if (info.name == "dropout2_prob") {
      (*attribute_map)[info.name] = pir::FloatAttribute::get(ctx, .5f);
    } else if (info.name == "dropout1_implementation") {
      (*attribute_map)[info.name] =
          pir::StrAttribute::get(ctx, "downgrade_in_infer");
    } else if (info.name == "dropout2_implementation") {
      (*attribute_map)[info.name] =
          pir::StrAttribute::get(ctx, "downgrade_in_infer");
    } else if (info.name == "is_test") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    } else if (info.name == "dropout1_fix_seed") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    } else if (info.name == "dropout2_fix_seed") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    } else if (info.name == "dropout1_seed_val") {
      // Fixed: seed values are int32 attributes; `false` was being passed
      // where the integer 0 was intended (same value, wrong-typed literal).
      (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 0);
    } else if (info.name == "dropout2_seed_val") {
      (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 0);
    } else if (info.name == "add_residual") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, true);
    } else if (info.name == "ring_id") {
      (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, -1);
    }
  }

  // Additionally maps the legacy `Out` variable to the fused op's `out`
  // result (the default mapping does not cover it).
  void RecordOpResultMapping(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Operation* operation,
                             const OpOutputMapping& arg_to_idx) override {
    OpTranscriber::RecordOpResultMapping(
        ctx, param_map, op_desc, operation, arg_to_idx);
    if (op_desc.HasOutput("Out")) {
      const auto& output_vars = op_desc.Output("Out");
      PADDLE_ENFORCE_EQ(output_vars.size(),
                        1UL,
                        common::errors::InvalidArgument(
                            "Expected op[%s]'s Out has only 1 var but got %s",
                            op_desc.Type(),
                            output_vars.size()));
      auto output_var = output_vars[0];
      auto fused_feedforward_op =
          operation->dyn_cast<dialect::FusedFeedforwardOp>();
      param_map->PushValue(output_var,
                           VariableDefiningInfo{fused_feedforward_op.out()});
    }
  }
};

struct ShareBufferOpTranscriber : public OpTranscriber {
  // share_buffer lowers to the inplace pd_op.share_data_ op.
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    const std::string target_op_name = dialect::ShareData_Op::name();
    auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op share_buffer should have corresponding OpInfo "
          "pd_op.share_data_"));
    }
    return op_info;
  }
};

// randint's output type comes from the legacy `dtype` attribute plus the
// output VarDesc, so output generation is overridden instead of using the
// generic OpTranscriber path.
struct RandIntOpTranscriber : public OpTranscriber {
  // Builds the single DenseTensor output type for the legacy `Out` variable.
  std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
      pir::IrContext* ctx,
      const OpDesc& op_desc,
      const OpOutputInfoList& output_infos) override {
    OpOutputMapping arg_to_idx;
    OpOutputTypeList op_output_types = {};

    auto& type_translator = TypeTranslator::instance();

    const BlockDesc* block = op_desc.Block();
    std::string legacy_output_name = "Out";
    const auto& legacy_output_vars = op_desc.Output(legacy_output_name);
    auto& var_name = legacy_output_vars[0];
    VarDesc* var = block->FindVarRecursive(var_name);
    PADDLE_ENFORCE_NE(
        var,
        nullptr,
        common::errors::InvalidArgument(
            "[op:%s] Output %s should not be null", op_desc.Type(), var_name));
    // `dtype` is stored as a legacy proto VarType enum value.
    int dtype_attr_val = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));

    paddle::framework::proto::VarType::Type var_type =
        static_cast<paddle::framework::proto::VarType::Type>(dtype_attr_val);

    pir::Type dtype = type_translator[var_type](ctx, *var);
    paddle::dialect::DenseTensorTypeStorage::Dim dim =
        common::make_ddim(var->GetShape());
    // Layout/lod/offset take fixed defaults; only shape and dtype come from
    // the legacy var.
    paddle::dialect::DenseTensorTypeStorage::DataLayout layout =
        paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW;
    paddle::dialect::DenseTensorTypeStorage::LegacyLoD lod = {};
    size_t offset = 0;
    pir::Type translated_var_type = paddle::dialect::DenseTensorType::get(
        ctx, dtype, dim, layout, lod, offset);
    // randint has exactly one result: slot 0 of the new operation.
    arg_to_idx[var_name] = {0, 0};
    op_output_types.push_back(translated_var_type);
    return {op_output_types, arg_to_idx};
  }
};

struct RepeatInterLeaveOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name;
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      target_op_name = "pd_op.repeat_interleave_with_tensor_index";
    } else {
      target_op_name = "pd_op.repeat_interleave";
    }
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    return op_info;
  }

  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    std::vector<pir::Value> op_inputs;
    auto x_names = op_desc.Input("X", true);
    auto input = param_map->at(x_names[0]).value;
    op_inputs.push_back(input);
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      auto repeats_names = op_desc.Input("RepeatsTensor", true);
      input = param_map->at(repeats_names[0]).value;
      op_inputs.push_back(input);
    }
    return op_inputs;
  }
};

struct RepeatInterLeaveGradOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name;
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      target_op_name = "pd_op.repeat_interleave_with_tensor_index_grad";
    } else {
      target_op_name = "pd_op.repeat_interleave_grad";
    }
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    return op_info;
  }

  std::vector<pir::Value> GenerateOperationInput(
      pir::IrContext* ctx,
      TranslationContext* param_map,
      const OpDesc& op_desc,
      const std::string& normalized_op_name,
      const OpInputInfoList& input_infos,
      pir::Block* block) override {
    std::vector<pir::Value> op_inputs;
    auto x_names = op_desc.Input("X", true);
    auto input = param_map->at(x_names[0]).value;
    op_inputs.push_back(input);
    if (op_desc.HasInput("RepeatsTensor") &&
        !op_desc.Input("RepeatsTensor").empty()) {
      auto repeats_names = op_desc.Input("RepeatsTensor", true);
      input = param_map->at(repeats_names[0]).value;
      op_inputs.push_back(input);
    }
    auto out_grad_names = op_desc.Input("Out@GRAD", true);
    input = param_map->at(out_grad_names[0]).value;
    op_inputs.push_back(input);

    return op_inputs;
  }
};

struct TopPSamplingOpTranscriber : public OpTranscriber {
  // Fills defaults for attributes absent from legacy top_p_sampling programs.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    const auto& name = info.name;
    if (name == "seed") {
      (*attribute_map)[name] = pir::Int32Attribute::get(ctx, -1);
    } else if (name == "k") {
      (*attribute_map)[name] = pir::Int32Attribute::get(ctx, 0);
    } else if (name == "mode") {
      (*attribute_map)[name] = pir::StrAttribute::get(ctx, "truncated");
    }
  }
};

struct FusedElemwiseAddActivationOpTranscriber : public OpTranscriber {
  // Fills defaults for attributes the legacy fused op may omit.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    const auto& name = info.name;
    if (name == "scale") {
      (*attribute_map)[name] = pir::FloatAttribute::get(ctx, 0.0);
    } else if (name == "axis") {
      (*attribute_map)[name] = pir::Int32Attribute::get(ctx, -1);
    } else if (name == "save_intermediate_out") {
      (*attribute_map)[name] = pir::BoolAttribute::get(ctx, false);
    }
  }
};

struct FusedElemwiseAddActivationGradOpTranscriber
    : public FusedElemwiseAddActivationOpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    // The PIR grad op has no IntermediateOut@GRAD result; reject legacy
    // programs that expect one.
    if (!op_desc.Output("IntermediateOut@GRAD").empty()) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "pd_op.fused_elemwise_add_activation_grad doesn't have "
          "Intermediate_out_grad output"));
    }
    return OpTranscriber::LookUpOpInfo(ctx, op_desc);
  }
};

// a more general version for fake quantize ops
// if one has a more special property, then don't use this
struct FakeQuantizeOpTranscriber : public OpTranscriber {
  // Fills defaults for attributes missing from legacy fake-quantize programs.
  // Note: the original contained a second, unreachable `round_type` branch
  // (duplicate condition in the else-if chain); it has been removed.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "round_type") {
      (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 1);
    } else if (info.name == "x_num_col_dims") {
      (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 1);
    } else if (info.name == "quant_axis") {
      (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, 0);
    }
  }
};

struct MatrixRankOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    // A wired TolTensor input selects the tolerance-tensor variant.
    const bool with_tol =
        op_desc.HasInput("TolTensor") && !op_desc.Input("TolTensor").empty();
    const std::string target_op_name =
        with_tol ? "pd_op.matrix_rank_tol" : "pd_op.matrix_rank";
    pir::OpInfo op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op matrix_rank should have "
          "corresponding OpInfo pd_op.matrix_rank "
          "or "
          "pd_op.matrix_rank_tol."));
    }
    return op_info;
  }
};

// lod_array_length is translated to pd_op.array_length. The legacy `X` input
// may not have been translated yet (arrays can be created lazily), so a
// special input handler creates the array on demand.
struct LodArrayLengthOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name = dialect::ArrayLengthOp::name();
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op lod_array_length should have corresponding OpInfo "
          "pd_op.array_length"));
    }

    return op_info;
  }

  // Resolves the `x` input: reuse the already-translated value if present,
  // otherwise insert a create-array op for the legacy VarDesc.
  InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) override {
    if (input_name != "x") {
      return nullptr;
    }
    return [](pir::IrContext* ctx,
              TranslationContext* param_map,
              const OpDesc& op_desc,
              const std::string&,
              const OpInputInfo& info,
              pir::Block* block) -> pir::Value {
      VLOG(10) << "[" << op_desc.Type() << "][input `array`]";
      PADDLE_ENFORCE_EQ(
          op_desc.HasInput("X"),
          true,
          common::errors::InvalidArgument(
              "Op lod_array_length should have input `X` but not found"));
      const auto& vars = op_desc.Input("X");
      PADDLE_ENFORCE_EQ(
          vars.size(),
          1UL,
          common::errors::InvalidArgument("Input `X` should be one variable %s",
                                          op_desc.Type()));
      VLOG(10) << "[" << op_desc.Type() << "][input `x`] from " << vars[0];
      const VarDesc* var_desc = op_desc.Block()->FindVarRecursive(vars[0]);
      PADDLE_ENFORCE_NE(
          var_desc,
          nullptr,
          common::errors::InvalidArgument(
              "VarDesc `%s` should be exist in legacy program", vars[0]));
      auto defining_value = pir::Value(nullptr);
      if (param_map->count(var_desc->Name())) {
        VLOG(10) << "[" << op_desc.Type() << "][input `x`] var: " << vars[0]
                 << " have been created";
        defining_value = param_map->at(var_desc->Name()).value;
      } else {
        VLOG(10) << "[" << op_desc.Type() << "][input `x`] var: " << vars[0]
                 << " newly created";
        // The array was never defined before this op: materialize it now.
        auto create_array_op = InsertCreateArrayOp(ctx, block, var_desc);
        defining_value = create_array_op->result(0);
      }
      return defining_value;
    };
  }
};

// write_to_array is translated to the inplace pd_op.array_write_. Because
// the legacy op writes into its `Out` array in place, the `array` input of
// the new op is resolved from the legacy *output* variable.
struct WriteArrayOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name = dialect::ArrayWrite_Op::name();
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op write_to_array should have corresponding OpInfo "
          "pd_op.array_write_"));
    }

    return op_info;
  }

  // Resolves the `array` input from the legacy `Out` variable, creating the
  // array on demand if it has not been translated yet.
  InputHandlerFn GetSpecialInputHandlers(
      const std::string& input_name) override {
    if (input_name != "array") {
      return nullptr;
    }
    return [](pir::IrContext* ctx,
              TranslationContext* param_map,
              const OpDesc& op_desc,
              const std::string&,
              const OpInputInfo& info,
              pir::Block* block) -> pir::Value {
      VLOG(10) << "[" << op_desc.Type() << "][input `array`]";
      PADDLE_ENFORCE_EQ(
          op_desc.HasOutput("Out"),
          true,
          common::errors::InvalidArgument(
              "Op write_to_array should have output `Out` but not found"));
      const auto& vars = op_desc.Output("Out");
      PADDLE_ENFORCE_EQ(
          vars.size(),
          1UL,
          common::errors::InvalidArgument(
              "Output `Out` should be one variable %s", op_desc.Type()));
      VLOG(10) << "[" << op_desc.Type() << "][input `array`] from " << vars[0];
      const VarDesc* var_desc = op_desc.Block()->FindVarRecursive(vars[0]);
      PADDLE_ENFORCE_NE(
          var_desc,
          nullptr,
          common::errors::InvalidArgument(
              "VarDesc `%s` should be exist in legacy program", vars[0]));
      auto defining_value = pir::Value(nullptr);
      if (param_map->count(var_desc->Name())) {
        VLOG(10) << "[" << op_desc.Type() << "][input `array`] var: " << vars[0]
                 << " have been created";
        defining_value = param_map->at(var_desc->Name()).value;
      } else {
        VLOG(10) << "[" << op_desc.Type() << "][input `array`] var: " << vars[0]
                 << " newly created";
        // First touch of this array in the new program: materialize it.
        auto create_array_op = InsertCreateArrayOp(ctx, block, var_desc);
        defining_value = create_array_op->result(0);
      }
      return defining_value;
    };
  }
};

// read_from_array is translated to pd_op.array_read.
struct ReadArrayOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    pir::OpInfo op_info =
        ctx->GetRegisteredOpInfo(dialect::ArrayReadOp::name());
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op read_from_array should have corresponding OpInfo "
          "pd_op.read_array"));
    }
    return op_info;
  }
};

// slice over a dense tensor maps to pd_op.slice; slice over a tensor array
// maps to pd_op.slice_array_dense (single-tensor result) or pd_op.slice_array
// (array result), decided by the output variable's type.
//
// Fixes: the `Out` enforce messages said "input `Out`" (copy-paste from the
// `Input` checks); `input_var` is now null-checked before dereference, the
// same way `output_var` already was.
struct SliceOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    std::string target_op_name = dialect::SliceOp::name();

    PADDLE_ENFORCE_EQ(op_desc.HasInput("Input"),
                      true,
                      common::errors::InvalidArgument(
                          "op %s should have input `Input`", op_desc.Type()));
    const auto& input_vars = op_desc.Input("Input");
    PADDLE_ENFORCE_EQ(input_vars.size(),
                      1UL,
                      common::errors::InvalidArgument(
                          "op %s should have one input `Input`, but got %d.",
                          op_desc.Type(),
                          input_vars.size()));
    const auto* input_var = op_desc.Block()->FindVarRecursive(input_vars[0]);
    PADDLE_ENFORCE_NE(input_var,
                      nullptr,
                      common::errors::InvalidArgument(
                          "op %s should have non-empty input `%s`.",
                          op_desc.Type(),
                          input_vars[0]));
    if (input_var->GetType() == framework::proto::VarType::DENSE_TENSOR_ARRAY) {
      PADDLE_ENFORCE_EQ(op_desc.HasOutput("Out"),
                        true,
                        common::errors::InvalidArgument(
                            "op %s should have output `Out`", op_desc.Type()));
      const auto& output_vars = op_desc.Output("Out");
      PADDLE_ENFORCE_EQ(output_vars.size(),
                        1UL,
                        common::errors::InvalidArgument(
                            "op %s should have one output `Out`, but got %d.",
                            op_desc.Type(),
                            output_vars.size()));
      const auto* output_var =
          op_desc.Block()->FindVarRecursive(output_vars[0]);
      PADDLE_ENFORCE_NE(output_var,
                        nullptr,
                        common::errors::InvalidArgument(
                            "op %s should have non-empty output `%s`.",
                            op_desc.Type(),
                            output_vars[0]));

      // Slicing an array down to a single dense tensor uses the dense variant.
      if (output_var->GetType() == framework::proto::VarType::DENSE_TENSOR) {
        target_op_name = dialect::SliceArrayDenseOp::name();
      } else {
        target_op_name = dialect::SliceArrayOp::name();
      }
    }

    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op slice should have corresponding OpInfo %s", target_op_name));
    }

    return op_info;
  }
};

struct SoftmaxOpTranscriber : public OpTranscriber {
  // Defaults `axis` to -1 when the legacy softmax op omits it.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "axis") {
      return;
    }
    (*attribute_map)[info.name] = pir::Int32Attribute::get(ctx, -1);
  }
};

struct SoftmaxWithCrossEntropyOpTranscriber : public OpTranscriber {
  // Fills defaults for attributes the legacy op may omit.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    const auto& name = info.name;
    if (name == "use_softmax") {
      (*attribute_map)[name] = pir::BoolAttribute::get(ctx, true);
    } else if (name == "axis") {
      (*attribute_map)[name] = pir::Int32Attribute::get(ctx, -1);
    }
  }
};

// Translates the legacy `matmul` op into pd_op.matmul. The legacy scale
// attributes (Scale_x/Scale_y/Scale_out) are not representable in the new
// op, so translation only proceeds when they keep their default value 1.0;
// a non-unit `alpha` is materialized as an extra pd_op.scale on the output.
//
// Fixes: unqualified abs() on a float operand could bind the integer
// overload and silently truncate; the LookUpOpInfo error message was a
// copy-paste from the read_from_array transcriber.
struct LegacyMatmulOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    // Throws when `attr_name` is present and differs from `expected_value`
    // by more than the float tolerance.
    auto enforce_not_occur = [&](const std::string& attr_name,
                                 float expected_value) -> void {
      if (!op_desc.HasAttr(attr_name)) {
        return;
      }
      float v = PADDLE_GET_CONST(float, op_desc.GetAttr(attr_name));
      // Explicit two-sided compare instead of abs(): unqualified abs() may
      // resolve to the int overload and truncate the float operand.
      float diff = v - expected_value;
      if (diff > 1e-6f || diff < -1e-6f) {
        PADDLE_THROW(common::errors::InvalidArgument(
            "Expected op[%s]'s attr %s to be %f, but got %f",
            op_desc.Type(),
            attr_name,
            expected_value,
            v));
      }
    };

    enforce_not_occur("Scale_x", 1.0f);
    enforce_not_occur("Scale_y", 1.0f);
    enforce_not_occur("Scale_out", 1.0f);

    std::string target_op_name = dialect::MatmulOp::name();
    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op matmul should have corresponding OpInfo pd_op.matmul"));
    }

    return op_info;
  }

  // After the default result mapping, rewires `Out` through a pd_op.scale
  // when the legacy `alpha` attribute is not 1.0.
  void RecordOpResultMapping(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Operation* operation,
                             const OpOutputMapping& arg_to_idx) override {
    OpTranscriber::RecordOpResultMapping(
        ctx, param_map, op_desc, operation, arg_to_idx);

    float alpha = PADDLE_GET_CONST(float, op_desc.GetAttr("alpha"));
    // alpha == 1.0 (within tolerance) needs no extra scaling.
    float diff = alpha - 1.0f;
    if (diff < 1e-6f && diff > -1e-6f) {
      return;
    }

    const auto& output_vars = op_desc.Output("Out");
    PADDLE_ENFORCE_EQ(
        output_vars.size(),
        1UL,
        common::errors::InvalidArgument(
            "Expected op[%s]'s output `Out` has only 1 variable, but got %d",
            op_desc.Type(),
            output_vars.size()));

    auto idx_iter = arg_to_idx.find(output_vars[0]);
    if (idx_iter == arg_to_idx.end()) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "op[%s] should have got its `Out`", op_desc.Type()));
    }
    auto [idx_in_op, idx_in_vec] = idx_iter->second;
    VLOG(10) << "[output recording]"
             << "[" << op_desc.Type() << "]" << output_vars[0] << " "
             << idx_in_op << " " << idx_in_vec;

    // Scale the matmul result and re-bind `Out` to the scaled value.
    pir::Builder builder(ctx, operation->GetParent());
    pir::Value value = operation->result(idx_in_op);
    auto scale_op = builder.Build<dialect::ScaleOp>(value, alpha);
    param_map->PushValue(output_vars[0],
                         VariableDefiningInfo(scale_op.out(), false, -1));
  }

  // Legacy programs may omit transpose flags; default them to false.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name == "transpose_x" || info.name == "transpose_y") {
      (*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
    }
  }
};

struct CEmbeddingOpTranscriber : public OpTranscriber {
  // Defaults `vocab_size` to -1 when the legacy c_embedding op omits it.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    if (info.name != "vocab_size") {
      return;
    }
    (*attribute_map)[info.name] = pir::Int64Attribute::get(ctx, -1);
  }
};

struct GatherOpTranscriber : public OpTranscriber {
  // Turns an attribute-style input into a full op result; a missing legacy
  // attribute falls back to the int64 constant 0.
  pir::Value GetAttributeAsInput(pir::IrContext* ctx,
                                 pir::Block* block,
                                 const OpDesc& op_desc,
                                 const OpInputInfo& input_info) override {
    auto& op_normalizer = OpNameNormalizer::instance();
    const auto legacy_attr_name =
        op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);

    pir::Attribute new_attr;
    if (op_desc.HasAttr(legacy_attr_name)) {
      paddle::framework::Attribute legacy_attr =
          op_desc.GetAttr(legacy_attr_name);
      VLOG(10) << "[" << op_desc.Type() << "][attribute]"
               << " name: " << legacy_attr_name << " " << legacy_attr.index();
      new_attr = AttributeTranslator::instance()(legacy_attr);
    } else {
      VLOG(10) << "[" << op_desc.Type() << "][attribute]"
               << " name: " << legacy_attr_name << " not found and fill 0.";
      new_attr = pir::Int64Attribute::get(ctx, 0);
    }
    pir::Operation* defining_op =
        InsertFullOperationForAttributeInput(ctx, block, new_attr);
    return defining_op->result(0);
  }
};

struct QuantizeLinearOpTranscriber : public OpTranscriber {
  // Defaults for attributes missing from legacy quantize/dequantize programs.
  void HandleNonexistentAttribute(pir::IrContext* ctx,
                                  pir::AttributeMap* attribute_map,
                                  const OpAttributeInfo& info) override {
    const auto& name = info.name;
    if (name == "round_type") {
      (*attribute_map)[name] = pir::Int32Attribute::get(ctx, 0);
    } else if (name == "is_test") {
      (*attribute_map)[name] = pir::BoolAttribute::get(ctx, true);
    } else if (name == "only_observer") {
      (*attribute_map)[name] = pir::BoolAttribute::get(ctx, false);
    } else if (name == "moving_rate") {
      (*attribute_map)[name] = pir::FloatAttribute::get(ctx, 0.9);
    } else if (name == "qmin") {
      (*attribute_map)[name] = pir::Int32Attribute::get(ctx, -128);
    } else if (name == "qmax") {
      (*attribute_map)[name] = pir::Int32Attribute::get(ctx, 127);
    }
  }
};

// NOTE(Dev): helper functions for WithXShapeGradOpTranscriber
// Resolves the (xshape, out_grad) value pair for XShape-based grad ops.
// Replaces any previously translated XShape value with a fresh pd_op.data op
// (erasing the old defining op), and slices out_grad from a vector value if
// needed. Mutates param_map; returns {xshape_value, out_grad_value}.
static std::pair<pir::Value, pir::Value> ParseXAndOutGradValue(
    const OpDesc& op_desc,
    pir::IrContext* ctx,
    pir::Builder* builder,
    TranslationContext* param_map,
    pir::Block* block) {
  auto& input_xshape_name = op_desc.Input("XShape")[0];
  auto& input_outgrad_name = op_desc.Input("Out@GRAD")[0];
  pir::Value xshape_value;
  VLOG(10) << "create data op for " << input_xshape_name;
  auto var_desc = op_desc.Block()->FindVarRecursive(input_xshape_name);
  auto dtype = ::phi::TransToPhiDataType(var_desc->GetDataType());
  auto shape_vec = var_desc->GetShape();
  // NOTE(dev): the grad op depends on X instead of XShape; XShape's shape
  // carries one extra leading element, so erase the first element here.
  shape_vec.erase(shape_vec.begin());
  xshape_value = builder
                     ->Build<paddle::dialect::DataOp>(
                         input_xshape_name, shape_vec, dtype, phi::Place())
                     .result(0);

  VLOG(10) << "create data op for " << input_xshape_name << " done";

  // If XShape was already translated, redirect all uses to the new data op
  // and erase the stale defining op so it does not linger in the program.
  if (param_map->Has(input_xshape_name)) {
    auto value =
        param_map->at(input_xshape_name).value.dyn_cast<pir::OpResult>();
    auto* defining_op = value.owner();
    value.ReplaceAllUsesWith(xshape_value);
    param_map->PopValue(input_xshape_name);
    defining_op->Erase();
  }

  param_map->PushValue(input_xshape_name, xshape_value);
  PADDLE_ENFORCE_EQ(param_map->Has(input_outgrad_name),
                    true,
                    common::errors::InvalidArgument(
                        "Reshape2_Grad op does not have input Out@GRAD"));
  auto input_outgrad_value_info = param_map->at(input_outgrad_name);
  // A value generated by a vector must be sliced down to a single tensor
  // before it can be consumed here.
  if (input_outgrad_value_info.generated_by_vector) {
    InsertSliceOperationForTarget(
        ctx, param_map, block, input_outgrad_value_info, input_outgrad_name);
    input_outgrad_value_info = param_map->at(input_outgrad_name);
  }
  pir::Value input_outgrad_value = input_outgrad_value_info.value;

  PADDLE_ENFORCE_EQ(
      input_outgrad_value.type().isa<paddle::dialect::DenseTensorType>(),
      true,
      ::common::errors::InvalidArgument(
          "input type must be DenseTensorType, but received: %s.",
          input_outgrad_value.type()));

  return std::make_pair(xshape_value, input_outgrad_value);
}

// Resolves the axes operand for squeeze/unsqueeze-style grad ops, in
// priority order: AxesTensor input, AxesTensorList input, `axes` attribute
// (materialized as a full_int_array op).
static pir::Value ParseAxis(const OpDesc& op_desc,
                            TranslationContext* param_map,
                            pir::IrContext* ctx,
                            pir::Block* block) {
  // process axes
  if (op_desc.HasInput("AxesTensor") && !op_desc.Input("AxesTensor").empty()) {
    // get axis from input
    auto axis_var_list = op_desc.Input("AxesTensor");
    PADDLE_ENFORCE_EQ(
        axis_var_list.size(),
        1UL,
        common::errors::InvalidArgument(
            "axis tensor input of %s MUST be a tensor", op_desc.Type()));
    auto axis_defining_info = (*param_map)[axis_var_list[0]];
    return axis_defining_info.value;
  } else if (op_desc.HasInput("AxesTensorList") &&
             !op_desc.Input("AxesTensorList").empty()) {
    // Combine the tensor list into a single vector-typed value.
    auto* combine_op = InsertCombineOperationForTarget(
        ctx, param_map, block, op_desc.Input("AxesTensorList"));
    return combine_op->result(0);
  } else {
    // Fall back to the static `axes` attribute.
    auto& attribute_translator = AttributeTranslator::instance();
    pir::Attribute new_attr = attribute_translator(
        "paddle::dialect::IntArrayAttribute", op_desc.GetAttr("axes"));
    auto full_array_op =
        InsertFullArrayOperationForAttributeInput(ctx, block, new_attr);
    return full_array_op->result(0);
  }
}

// Translates grad ops whose legacy form takes XShape + Out@GRAD (e.g.
// reshape2_grad). Builds OpT directly rather than via the generic pipeline.
template <typename OpT>
struct WithXShapeGradOpTranscriber : public OpTranscriber {
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    VLOG(4) << "Translate " << op_desc.Type() << ".....";
    pir::Builder builder(ctx, block);
    auto [xshape_value, input_outgrad_value] =
        ParseXAndOutGradValue(op_desc, ctx, &builder, param_map, block);
    auto& out_name = op_desc.Output("X@GRAD")[0];
    // NOTE(Aurelius84): Even though we use xshape to construct grad op,
    // but in GradKernel we still use dx->dims by default.
    OpT grad_op = builder.Build<OpT>(xshape_value, input_outgrad_value);
    param_map->PushValue(out_name, grad_op.result(0));

    return grad_op.operation();
  }
};

// NOTE(dev): In case of squeeze_grad and unsqueeze_grad
// Like WithXShapeGradOpTranscriber, but additionally resolves an axes
// operand (squeeze_grad / unsqueeze_grad take X, Out@GRAD and axes).
template <typename OpT>
struct WithXShapeAndAxisGradOpTranscriber : public OpTranscriber {
  pir::Operation* operator()(pir::IrContext* ctx,
                             TranslationContext* param_map,
                             const OpDesc& op_desc,
                             pir::Block* block) override {
    VLOG(4) << "Translate " << op_desc.Type() << ".....";
    pir::Builder builder(ctx, block);
    auto [x_value, input_outgrad_value] =
        ParseXAndOutGradValue(op_desc, ctx, &builder, param_map, block);
    auto& out_name = op_desc.Output("X@GRAD")[0];
    // NOTE(Aurelius84): Even though we use xshape to construct grad op,
    // but in GradKernel we still use dx->dims by default.
    pir::Value axis = ParseAxis(op_desc, param_map, ctx, block);
    OpT grad_op = builder.Build<OpT>(x_value, input_outgrad_value, axis);
    param_map->PushValue(out_name, grad_op.result(0));

    return grad_op.operation();
  }
};

// c_sync_comm_stream is translated to the inplace pd_op.sync_comm_stream_.
struct SyncCommStreamOpTranscriber : public OpTranscriber {
  pir::OpInfo LookUpOpInfo(pir::IrContext* ctx,
                           const OpDesc& op_desc) override {
    pir::OpInfo op_info = ctx->GetRegisteredOpInfo("pd_op.sync_comm_stream_");
    if (!op_info) {
      PADDLE_THROW(common::errors::InvalidArgument(
          "Op c_sync_comm_stream should have corresponding "
          "OpInfo pd_op.sync_comm_stream_."));
    }
    return op_info;
  }
};

// softplus attribute translation: identical to the generic path except that
// `beta`/`threshold` are widened from float to pir::DoubleAttribute.
struct SoftPlusOpTranscriber : public OpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& attribute_translator = AttributeTranslator::instance();
    auto& op_normalizer = OpNameNormalizer::instance();
    pir::AttributeMap attribute_map = {};

    for (const auto& info : op_attr_infos) {
      auto legacy_attr_name =
          op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
      VLOG(10) << "[op: " << op_desc.Type()
               << "][attr] from: " << legacy_attr_name << " to: " << info.name;
      if (op_desc.HasAttr(legacy_attr_name)) {
        paddle::framework::Attribute legacy_attr =
            op_desc.GetAttr(legacy_attr_name);
        VLOG(10) << "attribute in " << op_desc.Type()
                 << " name: " << legacy_attr_name << " " << legacy_attr.index();
        pir::Attribute new_attr =
            attribute_translator(info.type_name, legacy_attr);
        // Widen float -> double for these two attributes.
        // NOTE(review): assumes the translated attr is a FloatAttribute;
        // dyn_cast would yield a null attribute otherwise — confirm upstream.
        if (legacy_attr_name == "beta" || legacy_attr_name == "threshold") {
          new_attr = pir::DoubleAttribute::get(
              ctx,
              static_cast<double>(
                  new_attr.dyn_cast<pir::FloatAttribute>().data()));
        }
        attribute_map[info.name] = new_attr;
      } else {
        this->HandleNonexistentAttribute(ctx, &attribute_map, info);
      }
    }
    return attribute_map;
  }
};

// logit attribute translation: identical to the generic path except that
// `eps` is widened from float to pir::DoubleAttribute.
struct LogitOpTranscriber : public OpTranscriber {
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& attribute_translator = AttributeTranslator::instance();
    auto& op_normalizer = OpNameNormalizer::instance();
    pir::AttributeMap attribute_map = {};

    for (const auto& info : op_attr_infos) {
      auto legacy_attr_name =
          op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
      VLOG(10) << "[op: " << op_desc.Type()
               << "][attr] from: " << legacy_attr_name << " to: " << info.name;
      if (op_desc.HasAttr(legacy_attr_name)) {
        paddle::framework::Attribute legacy_attr =
            op_desc.GetAttr(legacy_attr_name);
        VLOG(10) << "attribute in " << op_desc.Type()
                 << " name: " << legacy_attr_name << " " << legacy_attr.index();
        pir::Attribute new_attr =
            attribute_translator(info.type_name, legacy_attr);
        // Widen float -> double for `eps`.
        // NOTE(review): assumes the translated attr is a FloatAttribute;
        // dyn_cast would yield a null attribute otherwise — confirm upstream.
        if (legacy_attr_name == "eps") {
          new_attr = pir::DoubleAttribute::get(
              ctx,
              static_cast<double>(
                  new_attr.dyn_cast<pir::FloatAttribute>().data()));
        }
        attribute_map[info.name] = new_attr;
      } else {
        this->HandleNonexistentAttribute(ctx, &attribute_map, info);
      }
    }
    return attribute_map;
  }
};

struct Pad3dOpTranscriber : public OpTranscriber {
  // Attribute translation for the legacy `pad3d` / `pad3d_grad` ops.
  //
  // The legacy program stores the padding value as a float, while the PIR op
  // definition expects `pad_value` as a double, so the translated attribute
  // is widened here. All other attributes go through the generic
  // AttributeTranslator unchanged.
  pir::AttributeMap TranslateOpAttribute(
      pir::IrContext* ctx,
      const std::string& normalized_op_name,
      const OpAttributeInfoList& op_attr_infos,
      const OpDesc& op_desc) override {
    auto& attribute_translator = AttributeTranslator::instance();
    auto& op_normalizer = OpNameNormalizer::instance();
    pir::AttributeMap attribute_map = {};

    for (const auto& info : op_attr_infos) {
      // Map the PIR attribute name back to its legacy (ProgramDesc) name.
      auto legacy_attr_name =
          op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
      VLOG(10) << "[op: " << op_desc.Type()
               << "][attr] from: " << legacy_attr_name << " to: " << info.name;
      if (op_desc.HasAttr(legacy_attr_name)) {
        paddle::framework::Attribute legacy_attr =
            op_desc.GetAttr(legacy_attr_name);
        VLOG(10) << "attribute in " << op_desc.Type()
                 << " name: " << legacy_attr_name << " " << legacy_attr.index();
        pir::Attribute new_attr =
            attribute_translator(info.type_name, legacy_attr);
        if (info.name == "pad_value") {
          // Widen float -> double to match the PIR op definition. Guard the
          // cast: dyn_cast yields a null attribute when the translated value
          // is not a FloatAttribute, and calling data() on a null attribute
          // would dereference null storage.
          if (auto float_attr = new_attr.dyn_cast<pir::FloatAttribute>()) {
            new_attr = pir::DoubleAttribute::get(
                ctx, static_cast<double>(float_attr.data()));
          }
        }
        attribute_map[info.name] = new_attr;
      } else {
        // Fall back to the default value declared in the op YAML info.
        this->HandleNonexistentAttribute(ctx, &attribute_map, info);
      }
    }
    return attribute_map;
  }
};

OpTranslator::OpTranslator() {
  pir::IrContext* ctx = pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();

  general_handler = OpTranscriber();

  // --- ops that each need their own dedicated transcriber ---
  special_handlers["add_n"] = AddNOpTranscriber();
  special_handlers["sum"] = AddNOpTranscriber();
  special_handlers["argsort"] = ArgsortOpTranscriber();
  special_handlers["assign"] = AssignOpTranscriber();
  special_handlers["assign_value"] = AssignValueOpTranscriber();
  special_handlers["batch_norm"] = BatchNormOpTranscriber();
  special_handlers["range"] = ArangeOpTranscriber();
  special_handlers["cast"] = CastOpTranscriber();
  special_handlers["conv2d"] = Conv2dOpTranscriber();
  special_handlers["conv3d"] = Conv3dOpTranscriber();
  special_handlers["cross_entropy_with_softmax"] =
      CrossEntropyWithSoftmaxOpTranscriber();
  special_handlers["data"] = DataOpTranscriber();
  special_handlers["depthwise_conv2d"] = DepthwiseConv2dOpTranscriber();
  special_handlers["im2sequence"] = Im2sequenceOpTranscriber();
  special_handlers["feed"] = FeedOpTranscriber();
  special_handlers["fill_constant"] = FillConstantTranscriber();
  special_handlers["fused_feedforward"] = FusedFeedForwardOpTranscriber();
  special_handlers["fused_elemwise_add_activation"] =
      FusedElemwiseAddActivationOpTranscriber();
  special_handlers["fused_elemwise_add_activation_grad"] =
      FusedElemwiseAddActivationGradOpTranscriber();
  special_handlers["grad_add"] = GradAddOpTranscriber();
  special_handlers["increment"] = IncrementOpTranscriber();
  special_handlers["lookup_table"] = EmbeddingOpTranscriber();
  special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
  special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
  special_handlers["one_hot_v2"] = OneHotTranscriber();
  special_handlers["pool2d"] = Pool2dOpTranscriber();
  special_handlers["pool3d"] = Pool3dOpTranscriber();
  special_handlers["randint"] = RandIntOpTranscriber();
  special_handlers["repeat_interleave"] = RepeatInterLeaveOpTranscriber();
  special_handlers["repeat_interleave_grad"] =
      RepeatInterLeaveGradOpTranscriber();
  special_handlers["rnn"] = RnnOpTranscriber();
  special_handlers["set_value"] = LegacySetValueDispatcher();
  special_handlers["set_value_grad"] = SetValueGradOpTranscriber();
  special_handlers["shadow_output"] = ShadowOutputOpTranscriber();
  special_handlers["share_buffer"] = ShareBufferOpTranscriber();
  special_handlers["sequence_pool"] = SequencePoolOpTranscriber();
  special_handlers["dropout"] = DropoutOpTranscriber();
  special_handlers["scale"] = ScaleOpTranscriber();
  special_handlers["slice"] = SliceOpTranscriber();
  special_handlers["split"] = SplitOpTranscriber();
  special_handlers["top_p_sampling"] = TopPSamplingOpTranscriber();
  special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
  special_handlers["tril_triu_grad"] = TrilAndTriuGradOpTranscriber();
  special_handlers["matmul"] = LegacyMatmulOpTranscriber();
  special_handlers["matrix_rank"] = MatrixRankOpTranscriber();
  special_handlers["mul"] = MulOpTranscriber();
  special_handlers["mul_grad"] = MulGradOpTranscriber();
  special_handlers["select_input"] = SelectInputOpTranscriber();
  special_handlers["select_output"] = SelectOutputOpTranscriber();
  special_handlers["softmax"] = SoftmaxOpTranscriber();
  special_handlers["softmax_with_cross_entropy"] =
      SoftmaxWithCrossEntropyOpTranscriber();
  special_handlers["gather"] = GatherOpTranscriber();
  special_handlers["box_coder"] = BoxCoderOpTranscriber();
  special_handlers["c_embedding"] = CEmbeddingOpTranscriber();
  special_handlers["c_sync_comm_stream"] = SyncCommStreamOpTranscriber();

  // --- op families that share one transcriber, registered in bulk ---
  for (const char* name : {"leaky_relu", "leaky_relu_grad"}) {
    special_handlers[name] = LeakyReLUOpTranscriber();
  }
  // Every interpolate flavor (v1/v2, forward and grad) uses the same handler.
  for (const char* base : {"bilinear_interp",
                           "nearest_interp",
                           "trilinear_interp",
                           "bicubic_interp",
                           "linear_interp",
                           "bilinear_interp_v2",
                           "nearest_interp_v2",
                           "trilinear_interp_v2",
                           "bicubic_interp_v2",
                           "linear_interp_v2"}) {
    special_handlers[base] = InterpolateOpTranscriber();
    special_handlers[std::string(base) + "_grad"] = InterpolateOpTranscriber();
  }
  for (const char* name : {"fetch", "fetch_v2"}) {
    special_handlers[name] = FetchOpTranscriber();
  }
  for (const char* name : {"fake_quantize_moving_average_abs_max",
                           "fake_channel_wise_dequantize_max_abs",
                           "fake_quantize_range_abs_max",
                           "fake_quantize_dequantize_moving_average_abs_max"}) {
    special_handlers[name] = FakeQuantizeOpTranscriber();
  }
  for (const char* name : {"reduce_all", "reduce_any"}) {
    special_handlers[name] = ReduceOpTranscriber();
  }
  for (const char* name : {"quantize_linear", "dequantize_linear"}) {
    special_handlers[name] = QuantizeLinearOpTranscriber();
  }
  for (const char* name : {"softplus", "softplus_grad"}) {
    special_handlers[name] = SoftPlusOpTranscriber();
  }
  for (const char* name : {"logit", "logit_grad"}) {
    special_handlers[name] = LogitOpTranscriber();
  }
  for (const char* name : {"pad3d", "pad3d_grad"}) {
    special_handlers[name] = Pad3dOpTranscriber();
  }

  // To adapt DenseTensorArray
  special_handlers["lod_array_length"] = LodArrayLengthOpTranscriber();
  special_handlers["write_to_array"] = WriteArrayOpTranscriber();
  special_handlers["read_from_array"] = ReadArrayOpTranscriber();

  // special handler for elementwise ops with axis != -1
  // note(lyk): maybe we should do this by a pass, which seems more reasonable
  for (const char* base : {"elementwise_add",
                           "elementwise_sub",
                           "elementwise_mul",
                           "elementwise_div",
                           "elementwise_max",
                           "elementwise_min",
                           "elementwise_mod",
                           "elementwise_floordiv"}) {
    special_handlers[base] = ElementwiseTranscriber();
    special_handlers[std::string(base) + "_grad"] =
        ElementwiseGradTranscriber();
  }

  // To process Op with XShape output in old IR
  special_handlers["reshape2_grad"] =
      WithXShapeGradOpTranscriber<dialect::ReshapeGradOp>();
  special_handlers["flatten_contiguous_range_grad"] =
      WithXShapeGradOpTranscriber<dialect::FlattenGradOp>();
  special_handlers["squeeze2_grad"] =
      WithXShapeAndAxisGradOpTranscriber<dialect::SqueezeGradOp>();
  special_handlers["unsqueeze2_grad"] =
      WithXShapeAndAxisGradOpTranscriber<dialect::UnsqueezeGradOp>();
}
}  // namespace paddle::translator
